feat: make embedding support a list of strings as input
Makes the /v1/embedding route similar to the OpenAI API.
This commit is contained in:
parent
01a010be52
commit
e783f1c191
2 changed files with 30 additions and 18 deletions
|
@ -531,7 +531,9 @@ class Llama:
|
||||||
if tokens_or_none is not None:
|
if tokens_or_none is not None:
|
||||||
tokens.extend(tokens_or_none)
|
tokens.extend(tokens_or_none)
|
||||||
|
|
||||||
def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
|
def create_embedding(
|
||||||
|
self, input: Union[str, List[str]], model: Optional[str] = None
|
||||||
|
) -> Embedding:
|
||||||
"""Embed a string.
|
"""Embed a string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -551,30 +553,40 @@ class Llama:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
llama_cpp.llama_reset_timings(self.ctx)
|
llama_cpp.llama_reset_timings(self.ctx)
|
||||||
|
|
||||||
tokens = self.tokenize(input.encode("utf-8"))
|
if isinstance(input, str):
|
||||||
self.reset()
|
inputs = [input]
|
||||||
self.eval(tokens)
|
else:
|
||||||
n_tokens = len(tokens)
|
inputs = input
|
||||||
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
|
|
||||||
: llama_cpp.llama_n_embd(self.ctx)
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.verbose:
|
data = []
|
||||||
llama_cpp.llama_print_timings(self.ctx)
|
total_tokens = 0
|
||||||
|
for input in inputs:
|
||||||
|
tokens = self.tokenize(input.encode("utf-8"))
|
||||||
|
self.reset()
|
||||||
|
self.eval(tokens)
|
||||||
|
n_tokens = len(tokens)
|
||||||
|
total_tokens += n_tokens
|
||||||
|
embedding = llama_cpp.llama_get_embeddings(self.ctx)[
|
||||||
|
: llama_cpp.llama_n_embd(self.ctx)
|
||||||
|
]
|
||||||
|
|
||||||
return {
|
if self.verbose:
|
||||||
"object": "list",
|
llama_cpp.llama_print_timings(self.ctx)
|
||||||
"data": [
|
data.append(
|
||||||
{
|
{
|
||||||
"object": "embedding",
|
"object": "embedding",
|
||||||
"embedding": embedding,
|
"embedding": embedding,
|
||||||
"index": 0,
|
"index": 0,
|
||||||
}
|
}
|
||||||
],
|
)
|
||||||
"model": model_name,
|
|
||||||
|
return {
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
"model": self.model_path,
|
||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": n_tokens,
|
"prompt_tokens": total_tokens,
|
||||||
"total_tokens": n_tokens,
|
"total_tokens": total_tokens,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -275,7 +275,7 @@ def create_completion(
|
||||||
|
|
||||||
class CreateEmbeddingRequest(BaseModel):
|
class CreateEmbeddingRequest(BaseModel):
|
||||||
model: Optional[str] = model_field
|
model: Optional[str] = model_field
|
||||||
input: str = Field(description="The input to embed.")
|
input: Union[str, List[str]] = Field(description="The input to embed.")
|
||||||
user: Optional[str]
|
user: Optional[str]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
Loading…
Reference in a new issue