feat: make embedding support list of string as input

makes the /v1/embedding route similar to OpenAI api.
This commit is contained in:
Simon Chabot 2023-05-20 01:23:32 +02:00
parent 01a010be52
commit e783f1c191
2 changed files with 30 additions and 18 deletions

View file

@ -531,7 +531,9 @@ class Llama:
if tokens_or_none is not None: if tokens_or_none is not None:
tokens.extend(tokens_or_none) tokens.extend(tokens_or_none)
def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding: def create_embedding(
self, input: Union[str, List[str]], model: Optional[str] = None
) -> Embedding:
"""Embed a string. """Embed a string.
Args: Args:
@ -551,30 +553,40 @@ class Llama:
if self.verbose: if self.verbose:
llama_cpp.llama_reset_timings(self.ctx) llama_cpp.llama_reset_timings(self.ctx)
tokens = self.tokenize(input.encode("utf-8")) if isinstance(input, str):
self.reset() inputs = [input]
self.eval(tokens) else:
n_tokens = len(tokens) inputs = input
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx)
]
if self.verbose: data = []
llama_cpp.llama_print_timings(self.ctx) total_tokens = 0
for input in inputs:
tokens = self.tokenize(input.encode("utf-8"))
self.reset()
self.eval(tokens)
n_tokens = len(tokens)
total_tokens += n_tokens
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx)
]
return { if self.verbose:
"object": "list", llama_cpp.llama_print_timings(self.ctx)
"data": [ data.append(
{ {
"object": "embedding", "object": "embedding",
"embedding": embedding, "embedding": embedding,
"index": 0, "index": 0,
} }
], )
"model": model_name,
return {
"object": "list",
"data": data,
"model": self.model_path,
"usage": { "usage": {
"prompt_tokens": n_tokens, "prompt_tokens": total_tokens,
"total_tokens": n_tokens, "total_tokens": total_tokens,
}, },
} }

View file

@ -275,7 +275,7 @@ def create_completion(
class CreateEmbeddingRequest(BaseModel): class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field model: Optional[str] = model_field
input: str = Field(description="The input to embed.") input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str] user: Optional[str]
class Config: class Config: