feat: make embedding support list of string as input

makes the /v1/embedding route similar to OpenAI api.
This commit is contained in:
Simon Chabot 2023-05-20 01:23:32 +02:00
parent 01a010be52
commit e783f1c191
2 changed files with 30 additions and 18 deletions

View file

@ -531,7 +531,9 @@ class Llama:
if tokens_or_none is not None: if tokens_or_none is not None:
tokens.extend(tokens_or_none) tokens.extend(tokens_or_none)
def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding: def create_embedding(
self, input: Union[str, List[str]], model: Optional[str] = None
) -> Embedding:
"""Embed a string. """Embed a string.
Args: Args:
@ -551,30 +553,40 @@ class Llama:
if self.verbose: if self.verbose:
llama_cpp.llama_reset_timings(self.ctx) llama_cpp.llama_reset_timings(self.ctx)
if isinstance(input, str):
inputs = [input]
else:
inputs = input
data = []
total_tokens = 0
for input in inputs:
tokens = self.tokenize(input.encode("utf-8")) tokens = self.tokenize(input.encode("utf-8"))
self.reset() self.reset()
self.eval(tokens) self.eval(tokens)
n_tokens = len(tokens) n_tokens = len(tokens)
total_tokens += n_tokens
embedding = llama_cpp.llama_get_embeddings(self.ctx)[ embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx) : llama_cpp.llama_n_embd(self.ctx)
] ]
if self.verbose: if self.verbose:
llama_cpp.llama_print_timings(self.ctx) llama_cpp.llama_print_timings(self.ctx)
data.append(
return {
"object": "list",
"data": [
{ {
"object": "embedding", "object": "embedding",
"embedding": embedding, "embedding": embedding,
"index": 0, "index": 0,
} }
], )
"model": model_name,
return {
"object": "list",
"data": data,
"model": self.model_path,
"usage": { "usage": {
"prompt_tokens": n_tokens, "prompt_tokens": total_tokens,
"total_tokens": n_tokens, "total_tokens": total_tokens,
}, },
} }

View file

@ -275,7 +275,7 @@ def create_completion(
class CreateEmbeddingRequest(BaseModel): class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field model: Optional[str] = model_field
input: str = Field(description="The input to embed.") input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str] user: Optional[str]
class Config: class Config: