feat: make embedding support a list of strings as input

Makes the /v1/embedding route similar to the OpenAI API.
This commit is contained in:
Simon Chabot 2023-05-20 01:23:32 +02:00
parent 01a010be52
commit e783f1c191
2 changed files with 30 additions and 18 deletions

View file

@ -531,7 +531,9 @@ class Llama:
if tokens_or_none is not None:
tokens.extend(tokens_or_none)
def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
def create_embedding(
self, input: Union[str, List[str]], model: Optional[str] = None
) -> Embedding:
"""Embed a string.
Args:
@ -551,30 +553,40 @@ class Llama:
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
tokens = self.tokenize(input.encode("utf-8"))
self.reset()
self.eval(tokens)
n_tokens = len(tokens)
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx)
]
if isinstance(input, str):
inputs = [input]
else:
inputs = input
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
data = []
total_tokens = 0
for input in inputs:
tokens = self.tokenize(input.encode("utf-8"))
self.reset()
self.eval(tokens)
n_tokens = len(tokens)
total_tokens += n_tokens
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
: llama_cpp.llama_n_embd(self.ctx)
]
return {
"object": "list",
"data": [
if self.verbose:
llama_cpp.llama_print_timings(self.ctx)
data.append(
{
"object": "embedding",
"embedding": embedding,
"index": 0,
}
],
"model": model_name,
)
return {
"object": "list",
"data": data,
"model": self.model_path,
"usage": {
"prompt_tokens": n_tokens,
"total_tokens": n_tokens,
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
},
}

View file

@ -275,7 +275,7 @@ def create_completion(
class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field
input: str = Field(description="The input to embed.")
input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str]
class Config: