feat: Support batch embeddings (#1186)
* handle batched embeddings * fix normalization issue * fix type hints, ensure no breaking changes to embed * Clear kv cache / reset internal state after embedding complete --------- Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
parent
36b843228f
commit
d7a67917ba
2 changed files with 123 additions and 34 deletions
|
@ -510,6 +510,14 @@ class _LlamaBatch:
|
|||
self._llama_batch_free(self.batch)
|
||||
self.batch = None
|
||||
|
||||
def n_tokens(self) -> int:
|
||||
assert self.batch is not None
|
||||
return self.batch.n_tokens
|
||||
|
||||
def reset(self):
|
||||
assert self.batch is not None
|
||||
self.batch.n_tokens = 0
|
||||
|
||||
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
|
||||
assert self.batch is not None
|
||||
n_tokens = len(batch)
|
||||
|
@ -522,6 +530,20 @@ class _LlamaBatch:
|
|||
self.batch.logits[i] = logits_all
|
||||
self.batch.logits[n_tokens - 1] = True
|
||||
|
||||
def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
|
||||
assert self.batch is not None
|
||||
n_tokens = len(batch)
|
||||
n_tokens0 = self.batch.n_tokens
|
||||
self.batch.n_tokens += n_tokens
|
||||
for i in range(n_tokens):
|
||||
j = n_tokens0 + i
|
||||
self.batch.token[j] = batch[i]
|
||||
self.batch.pos[j] = i
|
||||
self.batch.seq_id[j][0] = seq_id
|
||||
self.batch.n_seq_id[j] = 1
|
||||
self.batch.logits[j] = logits_all
|
||||
self.batch.logits[n_tokens - 1] = True
|
||||
|
||||
|
||||
class _LlamaTokenDataArray:
|
||||
def __init__(self, *, n_vocab: int):
|
||||
|
|
|
@ -717,10 +717,53 @@ class Llama:
|
|||
Returns:
|
||||
An embedding object.
|
||||
"""
|
||||
assert self._ctx.ctx is not None
|
||||
assert self._model.model is not None
|
||||
model_name: str = model if model is not None else self.model_path
|
||||
|
||||
# get numeric embeddings
|
||||
embeds: List[List[float]]
|
||||
total_tokens: int
|
||||
embeds, total_tokens = self.embed(input, return_count=True) # type: ignore
|
||||
|
||||
# convert to CreateEmbeddingResponse
|
||||
data: List[Embedding] = [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": emb,
|
||||
"index": idx,
|
||||
}
|
||||
for idx, emb in enumerate(embeds)
|
||||
]
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": model_name,
|
||||
"usage": {
|
||||
"prompt_tokens": total_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
def embed(
|
||||
self,
|
||||
input: Union[str, List[str]],
|
||||
normalize: bool = True,
|
||||
truncate: bool = True,
|
||||
return_count: bool = False,
|
||||
):
|
||||
"""Embed a string.
|
||||
|
||||
Args:
|
||||
input: The utf-8 encoded string to embed.
|
||||
|
||||
Returns:
|
||||
A list of embeddings
|
||||
"""
|
||||
assert self._ctx.ctx is not None
|
||||
n_embd = self.n_embd()
|
||||
n_ctx = self.n_ctx()
|
||||
|
||||
if self.context_params.embedding == False:
|
||||
raise RuntimeError(
|
||||
"Llama model must be created with embedding=True to call this method"
|
||||
|
@ -734,48 +777,72 @@ class Llama:
|
|||
else:
|
||||
inputs = input
|
||||
|
||||
data: List[Embedding] = []
|
||||
# reset batch
|
||||
self._batch.reset()
|
||||
|
||||
# decode and fetch embeddings
|
||||
data: List[List[float]] = []
|
||||
def decode_batch(sizes: List[int]):
|
||||
assert self._ctx.ctx is not None
|
||||
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
|
||||
self._ctx.decode(self._batch)
|
||||
self._batch.reset()
|
||||
|
||||
# store embeddings
|
||||
for i, s in enumerate(sizes):
|
||||
embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
|
||||
:n_embd
|
||||
]
|
||||
norm = np.linalg.norm(embedding) if normalize else s
|
||||
embedding: List[float] = [v / float(norm) for v in embedding]
|
||||
data.append(embedding)
|
||||
|
||||
# init state
|
||||
total_tokens = 0
|
||||
for index, input in enumerate(inputs):
|
||||
tokens = self.tokenize(input.encode("utf-8"), special=True)
|
||||
self.reset()
|
||||
self.eval(tokens)
|
||||
t_batch = 0
|
||||
s_sizes: List[int] = []
|
||||
|
||||
# accumulate batches and encode
|
||||
for text in inputs:
|
||||
tokens = self.tokenize(text.encode("utf-8"))
|
||||
if truncate:
|
||||
tokens = tokens[:n_ctx]
|
||||
|
||||
n_tokens = len(tokens)
|
||||
total_tokens += n_tokens
|
||||
embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
|
||||
: llama_cpp.llama_n_embd(self._model.model)
|
||||
]
|
||||
|
||||
data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": embedding,
|
||||
"index": index,
|
||||
}
|
||||
)
|
||||
# check for overrun
|
||||
if n_tokens > n_ctx:
|
||||
raise ValueError(
|
||||
f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
|
||||
)
|
||||
|
||||
# time to eval batch
|
||||
if t_batch + n_tokens > self._n_ctx:
|
||||
decode_batch(s_sizes)
|
||||
t_batch = 0
|
||||
s_sizes = []
|
||||
|
||||
# add to batch
|
||||
self._batch.add_sequence(tokens, len(s_sizes), False)
|
||||
t_batch += n_tokens
|
||||
s_sizes.append(n_tokens)
|
||||
|
||||
# hanlde last batch
|
||||
decode_batch(s_sizes)
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_print_timings(self._ctx.ctx)
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": model_name,
|
||||
"usage": {
|
||||
"prompt_tokens": total_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
},
|
||||
}
|
||||
output = data[0] if isinstance(input, str) else data
|
||||
|
||||
def embed(self, input: str) -> List[float]:
|
||||
"""Embed a string.
|
||||
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
|
||||
self.reset()
|
||||
|
||||
Args:
|
||||
input: The utf-8 encoded string to embed.
|
||||
|
||||
Returns:
|
||||
A list of embeddings
|
||||
"""
|
||||
return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
|
||||
if return_count:
|
||||
return output, total_tokens
|
||||
else:
|
||||
return output
|
||||
|
||||
def _create_completion(
|
||||
self,
|
||||
|
|
Loading…
Reference in a new issue