feat: Allow for possibly non-pooled embeddings (#1380)

* allow for possibly non-pooled embeddings

* add more to embeddings section in README.md

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Douglas Hanley 2024-04-25 20:32:44 -05:00 committed by GitHub
parent fcfea66857
commit f6ed21f9a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 67 additions and 20 deletions

View file

@ -575,7 +575,7 @@ llama = Llama(
### Embeddings ### Embeddings
To generate text embeddings use [`create_embedding`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_embedding). To generate text embeddings use [`create_embedding`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_embedding) or [`embed`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.embed). Note that you must pass `embedding=True` to the constructor upon model creation for these to work properly.
```python ```python
import llama_cpp import llama_cpp
@ -589,6 +589,12 @@ embeddings = llm.create_embedding("Hello, world!")
embeddings = llm.create_embedding(["Hello, world!", "Goodbye, world!"]) embeddings = llm.create_embedding(["Hello, world!", "Goodbye, world!"])
``` ```
There are two primary notions of embeddings in a Transformer-style model: *token level* and *sequence level*. Sequence level embeddings are produced by "pooling" token level embeddings together, usually by averaging them or using the first token.
Models that are explicitly geared towards embeddings will usually return sequence level embeddings by default, one for each input string. Non-embedding models such as those designed for text generation will typically return only token level embeddings, one for each token in each sequence. Thus the dimensionality of the return type will be one higher for token level embeddings.
It is possible to control pooling behavior in some cases using the `pooling_type` flag on model creation. You can ensure token level embeddings from any model using `LLAMA_POOLING_TYPE_NONE`. The reverse, getting a generation oriented model to yield sequence level embeddings is currently not possible, but you can always do the pooling manually.
### Adjusting the Context Window ### Adjusting the Context Window
The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements. The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.

View file

@ -273,6 +273,10 @@ class _LlamaContext:
assert self.ctx is not None assert self.ctx is not None
return llama_cpp.llama_n_ctx(self.ctx) return llama_cpp.llama_n_ctx(self.ctx)
def pooling_type(self) -> int:
assert self.ctx is not None
return llama_cpp.llama_pooling_type(self.ctx)
def kv_cache_clear(self): def kv_cache_clear(self):
assert self.ctx is not None assert self.ctx is not None
llama_cpp.llama_kv_cache_clear(self.ctx) llama_cpp.llama_kv_cache_clear(self.ctx)
@ -641,6 +645,16 @@ def _should_add_bos(model: _LlamaModel) -> bool:
return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM
# Embedding functions
def _normalize_embedding(embedding):
norm = float(np.linalg.norm(embedding))
if norm == 0.0:
return embedding
return [v / norm for v in embedding]
# Python wrappers over common/sampling structs # Python wrappers over common/sampling structs

View file

@ -50,6 +50,7 @@ from ._internals import (
_LlamaTokenDataArray, # type: ignore _LlamaTokenDataArray, # type: ignore
_LlamaSamplingParams, # type: ignore _LlamaSamplingParams, # type: ignore
_LlamaSamplingContext, # type: ignore _LlamaSamplingContext, # type: ignore
_normalize_embedding, # type: ignore
) )
from ._logger import set_verbose from ._logger import set_verbose
from ._utils import suppress_stdout_stderr from ._utils import suppress_stdout_stderr
@ -760,7 +761,7 @@ class Llama:
input = input if isinstance(input, list) else [input] input = input if isinstance(input, list) else [input]
# get numeric embeddings # get numeric embeddings
embeds: List[List[float]] embeds: Union[List[List[float]], List[List[List[float]]]]
total_tokens: int total_tokens: int
embeds, total_tokens = self.embed(input, return_count=True) # type: ignore embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
@ -787,7 +788,7 @@ class Llama:
def embed( def embed(
self, self,
input: Union[str, List[str]], input: Union[str, List[str]],
normalize: bool = True, normalize: bool = False,
truncate: bool = True, truncate: bool = True,
return_count: bool = False, return_count: bool = False,
): ):
@ -803,6 +804,10 @@ class Llama:
n_embd = self.n_embd() n_embd = self.n_embd()
n_batch = self.n_batch n_batch = self.n_batch
# get pooling information
pooling_type = self.pooling_type()
logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
if self.context_params.embeddings == False: if self.context_params.embeddings == False:
raise RuntimeError( raise RuntimeError(
"Llama model must be created with embedding=True to call this method" "Llama model must be created with embedding=True to call this method"
@ -820,29 +825,37 @@ class Llama:
self._batch.reset() self._batch.reset()
# decode and fetch embeddings # decode and fetch embeddings
data: List[List[float]] = [] data: Union[List[List[float]], List[List[List[float]]]] = []
def decode_batch(n_seq: int): def decode_batch(seq_sizes: List[int]):
assert self._ctx.ctx is not None assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx) llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
self._ctx.decode(self._batch) self._ctx.decode(self._batch)
self._batch.reset() self._batch.reset()
# store embeddings # store embeddings
for i in range(n_seq): if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
ptr = llama_cpp.llama_get_embeddings_seq( pos: int = 0
self._ctx.ctx, i for i, size in enumerate(seq_sizes):
) ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
if not ptr: embedding: List[List[float]] = [
raise RuntimeError("Failed to get embeddings from sequence pooling type is not set") ptr[pos + j * n_embd : pos + (j + 1) * n_embd] for j in range(size)
embedding: List[float] = ptr[:n_embd] ]
if normalize: if normalize:
norm = float(np.linalg.norm(embedding)) embedding = [_normalize_embedding(e) for e in embedding]
embedding = [v / norm for v in embedding] data.append(embedding)
data.append(embedding) pos += size
else:
for i in range(len(seq_sizes)):
ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
embedding: List[float] = ptr[:n_embd]
if normalize:
embedding = _normalize_embedding(embedding)
data.append(embedding)
# init state # init state
total_tokens = 0 total_tokens = 0
s_batch = []
t_batch = 0 t_batch = 0
p_batch = 0 p_batch = 0
@ -863,17 +876,21 @@ class Llama:
# time to eval batch # time to eval batch
if t_batch + n_tokens > n_batch: if t_batch + n_tokens > n_batch:
decode_batch(p_batch) decode_batch(s_batch)
s_batch = []
t_batch = 0 t_batch = 0
p_batch = 0 p_batch = 0
# add to batch # add to batch
self._batch.add_sequence(tokens, p_batch, False) self._batch.add_sequence(tokens, p_batch, logits_all)
# update batch stats
s_batch.append(n_tokens)
t_batch += n_tokens t_batch += n_tokens
p_batch += 1 p_batch += 1
# hanlde last batch # hanlde last batch
decode_batch(p_batch) decode_batch(s_batch)
if self.verbose: if self.verbose:
llama_cpp.llama_print_timings(self._ctx.ctx) llama_cpp.llama_print_timings(self._ctx.ctx)
@ -1845,6 +1862,10 @@ class Llama:
"""Return the newline token.""" """Return the newline token."""
return self._model.token_nl() return self._model.token_nl()
def pooling_type(self) -> str:
"""Return the pooling type."""
return self._ctx.pooling_type()
@staticmethod @staticmethod
def logits_to_logprobs( def logits_to_logprobs(
logits: Union[npt.NDArray[np.single], List], axis: int = -1 logits: Union[npt.NDArray[np.single], List], axis: int = -1

View file

@ -1189,6 +1189,12 @@ def llama_rope_type(model: llama_model_p, /) -> int:
... ...
# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_model * model);
@ctypes_function("llama_pooling_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_pooling_type(model: llama_model_p, /) -> int:
...
# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int: def llama_n_vocab(model: llama_model_p, /) -> int:

View file

@ -24,7 +24,7 @@ class EmbeddingUsage(TypedDict):
class Embedding(TypedDict): class Embedding(TypedDict):
index: int index: int
object: str object: str
embedding: List[float] embedding: Union[List[float], List[List[float]]]
class CreateEmbeddingResponse(TypedDict): class CreateEmbeddingResponse(TypedDict):