feat: Allow for possibly non-pooled embeddings (#1380)
* allow for possibly non-pooled embeddings * add more to embeddings section in README.md --------- Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
parent
fcfea66857
commit
f6ed21f9a2
5 changed files with 67 additions and 20 deletions
|
@ -575,7 +575,7 @@ llama = Llama(
|
|||
|
||||
### Embeddings
|
||||
|
||||
To generate text embeddings use [`create_embedding`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_embedding).
|
||||
To generate text embeddings use [`create_embedding`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_embedding) or [`embed`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.embed). Note that you must pass `embedding=True` to the constructor upon model creation for these to work properly.
|
||||
|
||||
```python
|
||||
import llama_cpp
|
||||
|
@ -589,6 +589,12 @@ embeddings = llm.create_embedding("Hello, world!")
|
|||
embeddings = llm.create_embedding(["Hello, world!", "Goodbye, world!"])
|
||||
```
|
||||
|
||||
There are two primary notions of embeddings in a Transformer-style model: *token level* and *sequence level*. Sequence level embeddings are produced by "pooling" token level embeddings together, usually by averaging them or using the first token.
|
||||
|
||||
Models that are explicitly geared towards embeddings will usually return sequence level embeddings by default, one for each input string. Non-embedding models such as those designed for text generation will typically return only token level embeddings, one for each token in each sequence. Thus the dimensionality of the return type will be one higher for token level embeddings.
|
||||
|
||||
It is possible to control pooling behavior in some cases using the `pooling_type` flag on model creation. You can ensure token level embeddings from any model using `LLAMA_POOLING_TYPE_NONE`. The reverse, getting a generation oriented model to yield sequence level embeddings is currently not possible, but you can always do the pooling manually.
|
||||
|
||||
### Adjusting the Context Window
|
||||
|
||||
The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.
|
||||
|
|
|
@ -273,6 +273,10 @@ class _LlamaContext:
|
|||
assert self.ctx is not None
|
||||
return llama_cpp.llama_n_ctx(self.ctx)
|
||||
|
||||
def pooling_type(self) -> int:
|
||||
assert self.ctx is not None
|
||||
return llama_cpp.llama_pooling_type(self.ctx)
|
||||
|
||||
def kv_cache_clear(self):
|
||||
assert self.ctx is not None
|
||||
llama_cpp.llama_kv_cache_clear(self.ctx)
|
||||
|
@ -641,6 +645,16 @@ def _should_add_bos(model: _LlamaModel) -> bool:
|
|||
return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM
|
||||
|
||||
|
||||
# Embedding functions
|
||||
|
||||
|
||||
def _normalize_embedding(embedding):
|
||||
norm = float(np.linalg.norm(embedding))
|
||||
if norm == 0.0:
|
||||
return embedding
|
||||
return [v / norm for v in embedding]
|
||||
|
||||
|
||||
# Python wrappers over common/sampling structs
|
||||
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ from ._internals import (
|
|||
_LlamaTokenDataArray, # type: ignore
|
||||
_LlamaSamplingParams, # type: ignore
|
||||
_LlamaSamplingContext, # type: ignore
|
||||
_normalize_embedding, # type: ignore
|
||||
)
|
||||
from ._logger import set_verbose
|
||||
from ._utils import suppress_stdout_stderr
|
||||
|
@ -760,7 +761,7 @@ class Llama:
|
|||
input = input if isinstance(input, list) else [input]
|
||||
|
||||
# get numeric embeddings
|
||||
embeds: List[List[float]]
|
||||
embeds: Union[List[List[float]], List[List[List[float]]]]
|
||||
total_tokens: int
|
||||
embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
|
||||
|
||||
|
@ -787,7 +788,7 @@ class Llama:
|
|||
def embed(
|
||||
self,
|
||||
input: Union[str, List[str]],
|
||||
normalize: bool = True,
|
||||
normalize: bool = False,
|
||||
truncate: bool = True,
|
||||
return_count: bool = False,
|
||||
):
|
||||
|
@ -803,6 +804,10 @@ class Llama:
|
|||
n_embd = self.n_embd()
|
||||
n_batch = self.n_batch
|
||||
|
||||
# get pooling information
|
||||
pooling_type = self.pooling_type()
|
||||
logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
|
||||
|
||||
if self.context_params.embeddings == False:
|
||||
raise RuntimeError(
|
||||
"Llama model must be created with embedding=True to call this method"
|
||||
|
@ -820,29 +825,37 @@ class Llama:
|
|||
self._batch.reset()
|
||||
|
||||
# decode and fetch embeddings
|
||||
data: List[List[float]] = []
|
||||
data: Union[List[List[float]], List[List[List[float]]]] = []
|
||||
|
||||
def decode_batch(n_seq: int):
|
||||
def decode_batch(seq_sizes: List[int]):
|
||||
assert self._ctx.ctx is not None
|
||||
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
|
||||
self._ctx.decode(self._batch)
|
||||
self._batch.reset()
|
||||
|
||||
# store embeddings
|
||||
for i in range(n_seq):
|
||||
ptr = llama_cpp.llama_get_embeddings_seq(
|
||||
self._ctx.ctx, i
|
||||
)
|
||||
if not ptr:
|
||||
raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
|
||||
if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
|
||||
pos: int = 0
|
||||
for i, size in enumerate(seq_sizes):
|
||||
ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
|
||||
embedding: List[List[float]] = [
|
||||
ptr[pos + j * n_embd : pos + (j + 1) * n_embd] for j in range(size)
|
||||
]
|
||||
if normalize:
|
||||
embedding = [_normalize_embedding(e) for e in embedding]
|
||||
data.append(embedding)
|
||||
pos += size
|
||||
else:
|
||||
for i in range(len(seq_sizes)):
|
||||
ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
|
||||
embedding: List[float] = ptr[:n_embd]
|
||||
if normalize:
|
||||
norm = float(np.linalg.norm(embedding))
|
||||
embedding = [v / norm for v in embedding]
|
||||
embedding = _normalize_embedding(embedding)
|
||||
data.append(embedding)
|
||||
|
||||
# init state
|
||||
total_tokens = 0
|
||||
s_batch = []
|
||||
t_batch = 0
|
||||
p_batch = 0
|
||||
|
||||
|
@ -863,17 +876,21 @@ class Llama:
|
|||
|
||||
# time to eval batch
|
||||
if t_batch + n_tokens > n_batch:
|
||||
decode_batch(p_batch)
|
||||
decode_batch(s_batch)
|
||||
s_batch = []
|
||||
t_batch = 0
|
||||
p_batch = 0
|
||||
|
||||
# add to batch
|
||||
self._batch.add_sequence(tokens, p_batch, False)
|
||||
self._batch.add_sequence(tokens, p_batch, logits_all)
|
||||
|
||||
# update batch stats
|
||||
s_batch.append(n_tokens)
|
||||
t_batch += n_tokens
|
||||
p_batch += 1
|
||||
|
||||
# hanlde last batch
|
||||
decode_batch(p_batch)
|
||||
decode_batch(s_batch)
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_print_timings(self._ctx.ctx)
|
||||
|
@ -1845,6 +1862,10 @@ class Llama:
|
|||
"""Return the newline token."""
|
||||
return self._model.token_nl()
|
||||
|
||||
def pooling_type(self) -> str:
|
||||
"""Return the pooling type."""
|
||||
return self._ctx.pooling_type()
|
||||
|
||||
@staticmethod
|
||||
def logits_to_logprobs(
|
||||
logits: Union[npt.NDArray[np.single], List], axis: int = -1
|
||||
|
|
|
@ -1189,6 +1189,12 @@ def llama_rope_type(model: llama_model_p, /) -> int:
|
|||
...
|
||||
|
||||
|
||||
# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_model * model);
|
||||
@ctypes_function("llama_pooling_type", [llama_model_p_ctypes], ctypes.c_int)
|
||||
def llama_pooling_type(model: llama_model_p, /) -> int:
|
||||
...
|
||||
|
||||
|
||||
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
|
||||
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
|
||||
def llama_n_vocab(model: llama_model_p, /) -> int:
|
||||
|
|
|
@ -24,7 +24,7 @@ class EmbeddingUsage(TypedDict):
|
|||
class Embedding(TypedDict):
|
||||
index: int
|
||||
object: str
|
||||
embedding: List[float]
|
||||
embedding: Union[List[float], List[List[float]]]
|
||||
|
||||
|
||||
class CreateEmbeddingResponse(TypedDict):
|
||||
|
|
Loading…
Reference in a new issue