From d7a67917ba5b601e146377c6d877893dc49bba83 Mon Sep 17 00:00:00 2001
From: Douglas Hanley
Date: Wed, 14 Feb 2024 03:26:09 -0600
Subject: [PATCH] feat: Support batch embeddings (#1186)

* handle batched embeddings

* fix normalization issue

* fix type hints, ensure no breaking changes to embed

* Clear kv cache / reset internal state after embedding complete

---------

Co-authored-by: Andrei
---
 llama_cpp/_internals.py |  22 +++++++
 llama_cpp/llama.py      | 135 ++++++++++++++++++++++++++++++----------
 2 files changed, 123 insertions(+), 34 deletions(-)

diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 9473d35..c60fdff 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -510,6 +510,14 @@ class _LlamaBatch:
             self._llama_batch_free(self.batch)
         self.batch = None
 
+    def n_tokens(self) -> int:
+        assert self.batch is not None
+        return self.batch.n_tokens
+
+    def reset(self):
+        assert self.batch is not None
+        self.batch.n_tokens = 0
+
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
@@ -522,6 +530,20 @@ class _LlamaBatch:
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True
 
+    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        n_tokens0 = self.batch.n_tokens
+        self.batch.n_tokens += n_tokens
+        for i in range(n_tokens):
+            j = n_tokens0 + i
+            self.batch.token[j] = batch[i]
+            self.batch.pos[j] = i
+            self.batch.seq_id[j][0] = seq_id
+            self.batch.n_seq_id[j] = 1
+            self.batch.logits[j] = logits_all
+        self.batch.logits[n_tokens0 + n_tokens - 1] = True
+
 
 class _LlamaTokenDataArray:
     def __init__(self, *, n_vocab: int):
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 8d726d3..3e09a20 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -717,10 +717,53 @@ class Llama:
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path
 
+        # get numeric embeddings
+        embeds: List[List[float]]
+        total_tokens: int
+        embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore
+
+        # convert to CreateEmbeddingResponse
+        data: List[Embedding] = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            }
+            for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(
+        self,
+        input: Union[str, List[str]],
+        normalize: bool = True,
+        truncate: bool = True,
+        return_count: bool = False,
+    ):
+        """Embed a string or list of strings.
+
+        Args:
+            input: The utf-8 encoded string (or list of strings) to embed.
+
+        Returns:
+            A list of embeddings, one per input string (a single embedding if input is a string)
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
@@ -734,48 +777,72 @@
         else:
             inputs = input
 
-        data: List[Embedding] = []
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
+        data: List[List[float]] = []
+        def decode_batch(sizes: List[int]):
+            assert self._ctx.ctx is not None
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i, s in enumerate(sizes):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+                    :n_embd
+                ]
+                norm = np.linalg.norm(embedding) if normalize else s
+                embedding: List[float] = [v / float(norm) for v in embedding]
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        t_batch = 0
+        s_sizes: List[int] = []
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
+
             n_tokens = len(tokens)
             total_tokens += n_tokens
 
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]
-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )
+
+            # time to eval batch
+            if t_batch + n_tokens > self._n_ctx:
+                decode_batch(s_sizes)
+                t_batch = 0
+                s_sizes = []
+
+            # add to batch
+            self._batch.add_sequence(tokens, len(s_sizes), False)
+            t_batch += n_tokens
+            s_sizes.append(n_tokens)
+
+        # handle last batch
+        decode_batch(s_sizes)
+
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
 
-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
+        output = data[0] if isinstance(input, str) else data
 
-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
+        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        self.reset()
 
-        Args:
-            input: The utf-8 encoded string to embed.
-
-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        if return_count:
+            return output, total_tokens
+        else:
+            return output
 
     def _create_completion(
         self,
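
Below is a minimal usage sketch of the batched embedding API introduced by this
patch. It is not part of the commit itself; the model path and input strings are
placeholders, and only the entry points touched above (Llama(..., embedding=True),
Llama.embed, Llama.create_embedding) are assumed.

    from llama_cpp import Llama

    # The model must be loaded with embedding=True for embed()/create_embedding().
    llm = Llama(model_path="path/to/model.gguf", embedding=True)

    # A single string still returns one flat list of floats, as before.
    vec = llm.embed("hello world")

    # A list of strings is tokenized, packed into batches via add_sequence(),
    # and decoded together; one normalized embedding is returned per input.
    vecs = llm.embed(["first sentence", "second sentence"], normalize=True)

    # The OpenAI-compatible create_embedding() wraps embed() and reports the
    # token count accumulated across the whole batch.
    resp = llm.create_embedding(["first sentence", "second sentence"])
    print(len(resp["data"]), resp["usage"]["total_tokens"])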