feat: Support batch embeddings (#1186)

* handle batched embeddings

* fix normalization issue

* fix type hints, ensure no breaking changes to embed

* Clear kv cache / reset internal state after embedding complete

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Douglas Hanley 2024-02-14 03:26:09 -06:00 committed by GitHub
parent 36b843228f
commit d7a67917ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 123 additions and 34 deletions

View file

@ -510,6 +510,14 @@ class _LlamaBatch:
self._llama_batch_free(self.batch) self._llama_batch_free(self.batch)
self.batch = None self.batch = None
def n_tokens(self) -> int:
    """Return the token count currently stored on the underlying batch.

    Raises:
        AssertionError: if the underlying batch is ``None``.
    """
    current = self.batch
    assert current is not None
    return current.n_tokens
def reset(self):
    """Mark the batch as empty by zeroing its token count.

    Only the ``n_tokens`` counter is touched; the per-token arrays are left
    as they are and are expected to be overwritten by later writes.

    Raises:
        AssertionError: if the underlying batch is ``None``.
    """
    current = self.batch
    assert current is not None
    current.n_tokens = 0
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
assert self.batch is not None assert self.batch is not None
n_tokens = len(batch) n_tokens = len(batch)
@ -522,6 +530,20 @@ class _LlamaBatch:
self.batch.logits[i] = logits_all self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True self.batch.logits[n_tokens - 1] = True
def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
    """Append one token sequence to the tokens already held in the batch.

    Unlike ``set_batch``, this does not overwrite the batch from index 0: the
    new tokens are written after the ``self.batch.n_tokens`` entries that are
    already present, so several sequences can be packed into one batch.

    Args:
        batch: Token ids of the sequence to append.
        seq_id: Sequence id recorded for every appended token.
        logits_all: Value written to the logits flag of every appended token;
            the last appended token is always flagged ``True``.

    Raises:
        AssertionError: if the underlying batch is ``None``.
    """
    assert self.batch is not None
    n_tokens = len(batch)
    # Offset at which this sequence starts inside the shared batch arrays.
    n_tokens0 = self.batch.n_tokens
    self.batch.n_tokens += n_tokens
    for i, token in enumerate(batch):
        j = n_tokens0 + i
        self.batch.token[j] = token
        # Positions restart at 0 for each appended sequence.
        self.batch.pos[j] = i
        self.batch.seq_id[j][0] = seq_id
        self.batch.n_seq_id[j] = 1
        self.batch.logits[j] = logits_all
    # Flag the last token of *this* sequence. The index must include the
    # insertion offset; using ``n_tokens - 1`` alone would flag a slot that
    # belongs to an earlier sequence whenever the batch was not empty.
    # Skipped for an empty sequence, which would otherwise write to index -1
    # (presumably out of bounds if ``logits`` is a raw C array -- TODO confirm).
    if n_tokens > 0:
        self.batch.logits[n_tokens0 + n_tokens - 1] = True
class _LlamaTokenDataArray: class _LlamaTokenDataArray:
def __init__(self, *, n_vocab: int): def __init__(self, *, n_vocab: int):

View file

@ -717,10 +717,53 @@ class Llama:
Returns: Returns:
An embedding object. An embedding object.
""" """
assert self._ctx.ctx is not None
assert self._model.model is not None assert self._model.model is not None
model_name: str = model if model is not None else self.model_path model_name: str = model if model is not None else self.model_path
# get numeric embeddings
embeds: List[List[float]]
total_tokens: int
embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
# convert to CreateEmbeddingResponse
data: List[Embedding] = [
{
"object": "embedding",
"embedding": emb,
"index": idx,
}
for idx, emb in enumerate(embeds)
]
return {
"object": "list",
"data": data,
"model": model_name,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}
def embed(
self,
input: Union[str, List[str]],
normalize: bool = True,
truncate: bool = True,
return_count: bool = False,
):
"""Embed a string.
Args:
input: The utf-8 encoded string to embed.
Returns:
A list of embeddings
"""
assert self._ctx.ctx is not None
n_embd = self.n_embd()
n_ctx = self.n_ctx()
if self.context_params.embedding == False: if self.context_params.embedding == False:
raise RuntimeError( raise RuntimeError(
"Llama model must be created with embedding=True to call this method" "Llama model must be created with embedding=True to call this method"
@ -734,48 +777,72 @@ class Llama:
else: else:
inputs = input inputs = input
data: List[Embedding] = [] # reset batch
self._batch.reset()
# decode and fetch embeddings
data: List[List[float]] = []
def decode_batch(sizes: List[int]):
assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
self._ctx.decode(self._batch)
self._batch.reset()
# store embeddings
for i, s in enumerate(sizes):
embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
:n_embd
]
norm = np.linalg.norm(embedding) if normalize else s
embedding: List[float] = [v / float(norm) for v in embedding]
data.append(embedding)
# init state
total_tokens = 0 total_tokens = 0
for index, input in enumerate(inputs): t_batch = 0
tokens = self.tokenize(input.encode("utf-8"), special=True) s_sizes: List[int] = []
self.reset()
self.eval(tokens) # accumulate batches and encode
for text in inputs:
tokens = self.tokenize(text.encode("utf-8"))
if truncate:
tokens = tokens[:n_ctx]
n_tokens = len(tokens) n_tokens = len(tokens)
total_tokens += n_tokens total_tokens += n_tokens
embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
: llama_cpp.llama_n_embd(self._model.model)
]
data.append( # check for overrun
{ if n_tokens > n_ctx:
"object": "embedding", raise ValueError(
"embedding": embedding, f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
"index": index,
}
) )
# time to eval batch
if t_batch + n_tokens > self._n_ctx:
decode_batch(s_sizes)
t_batch = 0
s_sizes = []
# add to batch
self._batch.add_sequence(tokens, len(s_sizes), False)
t_batch += n_tokens
s_sizes.append(n_tokens)
# handle last batch
decode_batch(s_sizes)
if self.verbose: if self.verbose:
llama_cpp.llama_print_timings(self._ctx.ctx) llama_cpp.llama_print_timings(self._ctx.ctx)
return { output = data[0] if isinstance(input, str) else data
"object": "list",
"data": data,
"model": model_name,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}
def embed(self, input: str) -> List[float]: llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
"""Embed a string. self.reset()
Args: if return_count:
input: The utf-8 encoded string to embed. return output, total_tokens
else:
Returns: return output
A list of embeddings
"""
return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
def _create_completion( def _create_completion(
self, self,