feat: Support batch embeddings (#1186)
* handle batched embeddings
* fix normalization issue
* fix type hints, ensure no breaking changes to embed
* Clear kv cache / reset internal state after embedding complete

Co-authored-by: Andrei <abetlen@gmail.com>
parent 36b843228f
commit d7a67917ba

2 changed files with 123 additions and 34 deletions
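The user-visible effect of the change: `Llama.embed` now accepts a list of strings and decodes them in batches, and `create_embedding` builds its OpenAI-style response on top of that. A minimal usage sketch, assuming a local GGUF model (the path and inputs are placeholders, not from this commit):

from llama_cpp import Llama

# embedding=True is required; embed() raises RuntimeError otherwise
llm = Llama(model_path="models/example.gguf", embedding=True)

# several inputs embedded in one call, returned OpenAI-style
resp = llm.create_embedding(["first document", "second document"])
vectors = [item["embedding"] for item in resp["data"]]
print(resp["usage"]["prompt_tokens"])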
@@ -510,6 +510,14 @@ class _LlamaBatch:
         self._llama_batch_free(self.batch)
         self.batch = None
 
+    def n_tokens(self) -> int:
+        assert self.batch is not None
+        return self.batch.n_tokens
+
+    def reset(self):
+        assert self.batch is not None
+        self.batch.n_tokens = 0
+
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
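Note that `reset()` only rewinds the token counter; the underlying C buffer from `llama_batch_init` stays allocated, so one `_LlamaBatch` can be refilled across decode calls instead of being freed and rebuilt. A rough sketch of that pattern, with the import path and constructor arguments assumed rather than taken from this diff:

from llama_cpp._internals import _LlamaBatch  # import path is an assumption

tokens_a = [1, 2, 3]
tokens_b = [4, 5]

batch = _LlamaBatch(n_tokens=512, embd=0, n_seq_max=1, verbose=False)  # assumed ctor
batch.set_batch(tokens_a, n_past=0, logits_all=False)
# ... ctx.decode(batch) would run here ...
batch.reset()  # n_tokens -> 0; the buffer is reused, not freed
batch.set_batch(tokens_b, n_past=0, logits_all=False)
assert batch.n_tokens() == len(tokens_b)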
@@ -522,6 +530,20 @@ class _LlamaBatch:
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True
 
+    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        n_tokens0 = self.batch.n_tokens
+        self.batch.n_tokens += n_tokens
+        for i in range(n_tokens):
+            j = n_tokens0 + i
+            self.batch.token[j] = batch[i]
+            self.batch.pos[j] = i
+            self.batch.seq_id[j][0] = seq_id
+            self.batch.n_seq_id[j] = 1
+            self.batch.logits[j] = logits_all
+        self.batch.logits[n_tokens - 1] = True
+
 
 class _LlamaTokenDataArray:
     def __init__(self, *, n_vocab: int):
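To see what `add_sequence` produces, here is a pure-Python mock of the same field layout: tokens from different inputs are packed into one buffer, positions restart at 0 for each sequence, and `seq_id` keeps the sequences apart during decoding. `MockBatch` is an illustration only, not the real ctypes struct:

from typing import List

class MockBatch:
    def __init__(self, capacity: int):
        self.n_tokens = 0
        self.token = [0] * capacity
        self.pos = [0] * capacity
        self.seq_id = [[0] for _ in range(capacity)]
        self.n_seq_id = [0] * capacity
        self.logits = [False] * capacity

def add_sequence(batch: MockBatch, tokens: List[int], seq_id: int) -> None:
    # mirrors the logic above: append after the tokens already in the batch
    n_tokens0 = batch.n_tokens
    batch.n_tokens += len(tokens)
    for i, tok in enumerate(tokens):
        j = n_tokens0 + i
        batch.token[j] = tok
        batch.pos[j] = i             # positions restart per sequence
        batch.seq_id[j][0] = seq_id  # distinguishes the packed sequences
        batch.n_seq_id[j] = 1

batch = MockBatch(capacity=8)
add_sequence(batch, [101, 102, 103], seq_id=0)
add_sequence(batch, [201, 202], seq_id=1)
assert batch.pos[:5] == [0, 1, 2, 0, 1]
assert [s[0] for s in batch.seq_id[:5]] == [0, 0, 0, 1, 1]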
@@ -717,10 +717,53 @@ class Llama:
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path
 
+        # get numeric embeddings
+        embeds: List[List[float]]
+        total_tokens: int
+        embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore
+
+        # convert to CreateEmbeddingResponse
+        data: List[Embedding] = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            }
+            for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(
+        self,
+        input: Union[str, List[str]],
+        normalize: bool = True,
+        truncate: bool = True,
+        return_count: bool = False,
+    ):
+        """Embed a string.
+
+        Args:
+            input: The utf-8 encoded string to embed.
+
+        Returns:
+            A list of embeddings
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
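The rewritten `embed` keeps the old single-string behavior while adding list input and an opt-in token count, which `create_embedding` above uses for its `usage` field. Call shapes, assuming `llm` is a `Llama` created with `embedding=True`:

# str input -> one vector (List[float]); no breaking change for existing callers
vec = llm.embed("hello world")

# List[str] input -> one vector per input, decoded batch by batch internally
vecs = llm.embed(["first", "second"], normalize=True, truncate=True)

# return_count=True additionally returns the total prompt token count
vecs, total_tokens = llm.embed(["first", "second"], return_count=True)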
@@ -734,48 +777,72 @@ class Llama:
         else:
             inputs = input
 
-        data: List[Embedding] = []
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
+        data: List[List[float]] = []
+
+        def decode_batch(sizes: List[int]):
+            assert self._ctx.ctx is not None
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i, s in enumerate(sizes):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+                    :n_embd
+                ]
+                norm = np.linalg.norm(embedding) if normalize else s
+                embedding: List[float] = [v / float(norm) for v in embedding]
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        t_batch = 0
+        s_sizes: List[int] = []
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
+
             n_tokens = len(tokens)
             total_tokens += n_tokens
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]
 
-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )
+
+            # time to eval batch
+            if t_batch + n_tokens > self._n_ctx:
+                decode_batch(s_sizes)
+                t_batch = 0
+                s_sizes = []
+
+            # add to batch
+            self._batch.add_sequence(tokens, len(s_sizes), False)
+            t_batch += n_tokens
+            s_sizes.append(n_tokens)
+
+        # handle last batch
+        decode_batch(s_sizes)
+
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
 
-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
-
-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
-
-        Args:
-            input: The utf-8 encoded string to embed.
-
-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        output = data[0] if isinstance(input, str) else data
+
+        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        self.reset()
+
+        if return_count:
+            return output, total_tokens
+        else:
+            return output
 
     def _create_completion(
         self,
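On the "fix normalization issue" item in the commit message: in `decode_batch`, `normalize=True` scales each vector to unit L2 norm, while `normalize=False` divides by the sequence length `s` instead. A sketch of just that arithmetic on dummy values (what the raw floats represent is left to the underlying llama.cpp build and not asserted here):

import numpy as np

raw = [0.5, 1.0, 2.0]  # stand-in for one sequence's n_embd floats
s = 4                  # that sequence's token count

unit = [v / float(np.linalg.norm(raw)) for v in raw]  # normalize=True path
scaled = [v / float(s) for v in raw]                  # normalize=False path

assert abs(np.linalg.norm(unit) - 1.0) < 1e-9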