feat: Support batch embeddings (#1186)

* handle batched embeddings

* fix normalization issue

* fix type hints, ensure no breaking changes to embed

* Clear kv cache / reset internal state after embedding complete

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Douglas Hanley 2024-02-14 03:26:09 -06:00 committed by GitHub
parent 36b843228f
commit d7a67917ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 123 additions and 34 deletions

View file

@ -510,6 +510,14 @@ class _LlamaBatch:
self._llama_batch_free(self.batch)
self.batch = None
def n_tokens(self) -> int:
assert self.batch is not None
return self.batch.n_tokens
def reset(self):
assert self.batch is not None
self.batch.n_tokens = 0
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
@ -522,6 +530,20 @@ class _LlamaBatch:
self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True
def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
assert self.batch is not None
n_tokens = len(batch)
n_tokens0 = self.batch.n_tokens
self.batch.n_tokens += n_tokens
for i in range(n_tokens):
j = n_tokens0 + i
self.batch.token[j] = batch[i]
self.batch.pos[j] = i
self.batch.seq_id[j][0] = seq_id
self.batch.n_seq_id[j] = 1
self.batch.logits[j] = logits_all
self.batch.logits[n_tokens - 1] = True
class _LlamaTokenDataArray:
def __init__(self, *, n_vocab: int):

View file

@ -717,10 +717,53 @@ class Llama:
Returns:
An embedding object.
"""
assert self._ctx.ctx is not None
assert self._model.model is not None
model_name: str = model if model is not None else self.model_path
# get numeric embeddings
embeds: List[List[float]]
total_tokens: int
embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
# convert to CreateEmbeddingResponse
data: List[Embedding] = [
{
"object": "embedding",
"embedding": emb,
"index": idx,
}
for idx, emb in enumerate(embeds)
]
return {
"object": "list",
"data": data,
"model": model_name,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}
def embed(
self,
input: Union[str, List[str]],
normalize: bool = True,
truncate: bool = True,
return_count: bool = False,
):
"""Embed a string.
Args:
input: The utf-8 encoded string to embed.
Returns:
A list of embeddings
"""
assert self._ctx.ctx is not None
n_embd = self.n_embd()
n_ctx = self.n_ctx()
if self.context_params.embedding == False:
raise RuntimeError(
"Llama model must be created with embedding=True to call this method"
@ -734,48 +777,72 @@ class Llama:
else:
inputs = input
data: List[Embedding] = []
# reset batch
self._batch.reset()
# decode and fetch embeddings
data: List[List[float]] = []
def decode_batch(sizes: List[int]):
assert self._ctx.ctx is not None
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
self._ctx.decode(self._batch)
self._batch.reset()
# store embeddings
for i, s in enumerate(sizes):
embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
:n_embd
]
norm = np.linalg.norm(embedding) if normalize else s
embedding: List[float] = [v / float(norm) for v in embedding]
data.append(embedding)
# init state
total_tokens = 0
for index, input in enumerate(inputs):
tokens = self.tokenize(input.encode("utf-8"), special=True)
self.reset()
self.eval(tokens)
t_batch = 0
s_sizes: List[int] = []
# accumulate batches and encode
for text in inputs:
tokens = self.tokenize(text.encode("utf-8"))
if truncate:
tokens = tokens[:n_ctx]
n_tokens = len(tokens)
total_tokens += n_tokens
embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
: llama_cpp.llama_n_embd(self._model.model)
]
data.append(
{
"object": "embedding",
"embedding": embedding,
"index": index,
}
)
# check for overrun
if n_tokens > n_ctx:
raise ValueError(
f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
)
# time to eval batch
if t_batch + n_tokens > self._n_ctx:
decode_batch(s_sizes)
t_batch = 0
s_sizes = []
# add to batch
self._batch.add_sequence(tokens, len(s_sizes), False)
t_batch += n_tokens
s_sizes.append(n_tokens)
# hanlde last batch
decode_batch(s_sizes)
if self.verbose:
llama_cpp.llama_print_timings(self._ctx.ctx)
return {
"object": "list",
"data": data,
"model": model_name,
"usage": {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}
output = data[0] if isinstance(input, str) else data
def embed(self, input: str) -> List[float]:
"""Embed a string.
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
self.reset()
Args:
input: The utf-8 encoded string to embed.
Returns:
A list of embeddings
"""
return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
if return_count:
return output, total_tokens
else:
return output
def _create_completion(
self,