From e783f1c191ff66b45fba4ed7bd2821703952ca62 Mon Sep 17 00:00:00 2001 From: Simon Chabot Date: Sat, 20 May 2023 01:23:32 +0200 Subject: [PATCH] feat: make embedding support a list of strings as input makes the /v1/embedding route similar to the OpenAI API. --- llama_cpp/llama.py | 46 ++++++++++++++++++++++++++--------------- llama_cpp/server/app.py | 2 +- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 564c6c3..e854674 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -531,7 +531,9 @@ class Llama: if tokens_or_none is not None: tokens.extend(tokens_or_none) - def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding: + def create_embedding( + self, input: Union[str, List[str]], model: Optional[str] = None + ) -> Embedding: """Embed a string. Args: @@ -551,30 +553,40 @@ class Llama: if self.verbose: llama_cpp.llama_reset_timings(self.ctx) - tokens = self.tokenize(input.encode("utf-8")) - self.reset() - self.eval(tokens) - n_tokens = len(tokens) - embedding = llama_cpp.llama_get_embeddings(self.ctx)[ - : llama_cpp.llama_n_embd(self.ctx) - ] + if isinstance(input, str): + inputs = [input] + else: + inputs = input - if self.verbose: - llama_cpp.llama_print_timings(self.ctx) + data = [] + total_tokens = 0 + for input in inputs: + tokens = self.tokenize(input.encode("utf-8")) + self.reset() + self.eval(tokens) + n_tokens = len(tokens) + total_tokens += n_tokens + embedding = llama_cpp.llama_get_embeddings(self.ctx)[ + : llama_cpp.llama_n_embd(self.ctx) + ] - return { - "object": "list", - "data": [ + if self.verbose: + llama_cpp.llama_print_timings(self.ctx) + data.append( { "object": "embedding", "embedding": embedding, "index": 0, } - ], - "model": model_name, + ) + + return { + "object": "list", + "data": data, + "model": self.model_path, "usage": { - "prompt_tokens": n_tokens, - "total_tokens": n_tokens, + "prompt_tokens": total_tokens, + "total_tokens": total_tokens, }, } diff 
--git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 1ff0d1e..fea3612 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -275,7 +275,7 @@ def create_completion( class CreateEmbeddingRequest(BaseModel): model: Optional[str] = model_field - input: str = Field(description="The input to embed.") + input: Union[str, List[str]] = Field(description="The input to embed.") user: Optional[str] class Config: