Add model_alias option to override model_path in completions. Closes #39
parent 214589e462
commit a3352923c7
2 changed files with 34 additions and 9 deletions
llama_cpp/llama.py

@@ -522,7 +522,7 @@ class Llama:
             if tokens_or_none is not None:
                 tokens.extend(tokens_or_none)

-    def create_embedding(self, input: str) -> Embedding:
+    def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
         """Embed a string.

         Args:
@@ -532,6 +532,7 @@ class Llama:
             An embedding object.
         """
         assert self.ctx is not None
+        _model: str = model if model is not None else self.model_path

         if self.params.embedding == False:
             raise RuntimeError(
@@ -561,7 +562,7 @@ class Llama:
                     "index": 0,
                 }
             ],
-            "model": self.model_path,
+            "model": _model,
             "usage": {
                 "prompt_tokens": n_tokens,
                 "total_tokens": n_tokens,
@@ -598,6 +599,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ class Llama:
         text: bytes = b""
         returned_characters: int = 0
         stop = stop if stop is not None else []
+        _model: str = model if model is not None else self.model_path

         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ class Llama:
                     "id": completion_id,
                     "object": "text_completion",
                     "created": created,
-                    "model": self.model_path,
+                    "model": _model,
                     "choices": [
                         {
                             "text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ class Llama:
                 "id": completion_id,
                 "object": "text_completion",
                 "created": created,
-                "model": self.model_path,
+                "model": _model,
                 "choices": [
                     {
                         "text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ class Llama:
             "id": completion_id,
             "object": "text_completion",
             "created": created,
-            "model": self.model_path,
+            "model": _model,
             "choices": [
                 {
                     "text": text_str,
@@ -842,6 +845,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -883,6 +887,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.

@@ -950,6 +956,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )

     def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.

@@ -1064,6 +1072,7 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            model=model,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
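For context, a minimal usage sketch of the new model argument added above. It is illustrative only and not part of the commit; the model file path is a placeholder. Passing model= overrides the name reported in the response's "model" field without changing which weights are loaded, and when it is omitted the value falls back to self.model_path.

# Illustrative sketch, not part of this commit; the model path is a placeholder.
from llama_cpp import Llama

llama = Llama(model_path="./models/ggml-model.bin")

# Without `model`, the response reports the on-disk path.
out = llama("Q: Name the planets in the solar system. A:", max_tokens=32)
print(out["model"])  # "./models/ggml-model.bin"

# With `model`, only the reported name changes; the loaded weights do not.
out = llama("Q: Name the planets. A:", max_tokens=32, model="my-alias")
print(out["model"])  # "my-alias"

Per the hunks above, the same keyword is also accepted by create_completion, create_chat_completion, and create_embedding.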
llama_cpp/server/app.py

@@ -16,6 +16,10 @@ class Settings(BaseSettings):
     model: str = Field(
         description="The path to the model to use for generating completions."
     )
+    model_alias: Optional[str] = Field(
+        default=None,
+        description="The alias of the model to use for generating completions.",
+    )
     n_ctx: int = Field(default=2048, ge=1, description="The context size.")
     n_gpu_layers: int = Field(
         default=0,
@@ -64,6 +68,7 @@ class Settings(BaseSettings):

 router = APIRouter()

+settings: Optional[Settings] = None
 llama: Optional[llama_cpp.Llama] = None


@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
     if settings.cache:
         cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
         llama.set_cache(cache)
+
+    def set_settings(_settings: Settings):
+        global settings
+        settings = _settings
+
+    set_settings(settings)
     return app


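A hedged sketch of how a server could opt into the new alias, assuming the Settings and create_app shown above are importable from llama_cpp.server.app; the module path, model file, alias, and port below are all illustrative.

# Illustrative server setup; module path, file name, alias, and port are assumptions.
import uvicorn
from llama_cpp.server.app import Settings, create_app

settings = Settings(
    model="./models/ggml-model.bin",  # real path used to load the weights
    model_alias="gpt-3.5-turbo",      # name the API reports instead of the path
)
app = create_app(settings=settings)
uvicorn.run(app, host="localhost", port=8000)

Because Settings is a pydantic BaseSettings, the same fields can normally also be supplied through environment variables (for example MODEL and MODEL_ALIAS) when launching the server module directly.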
@@ -112,6 +123,10 @@ def get_llama():
         yield llama


+def get_settings():
+    yield settings
+
+
 model_field = Field(description="The model to use for generating completions.")

 max_tokens_field = Field(
@@ -236,7 +251,6 @@ def create_completion(
     completion_or_chunks = llama(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "best_of",
                 "logit_bias",
@@ -274,7 +288,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
 def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
-    return llama.create_embedding(**request.dict(exclude={"model", "user"}))
+    return llama.create_embedding(**request.dict(exclude={"user"}))


 class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
     completion_or_chunks = llama.create_chat_completion(
         **request.dict(
             exclude={
-                "model",
                 "n",
                 "logit_bias",
                 "user",
@@ -378,13 +391,16 @@ GetModelResponse = create_model_from_typeddict(ModelList)

 @router.get("/v1/models", response_model=GetModelResponse)
 def get_models(
+    settings: Settings = Depends(get_settings),
     llama: llama_cpp.Llama = Depends(get_llama),
 ) -> ModelList:
     return {
         "object": "list",
         "data": [
             {
-                "id": llama.model_path,
+                "id": settings.model_alias
+                if settings.model_alias is not None
+                else llama.model_path,
                 "object": "model",
                 "owned_by": "me",
                 "permissions": [],
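Finally, a quick check against a running server, illustrating how the /v1/models listing changes once model_alias is set; the URL and values are placeholders taken from the sketch above.

# Illustrative client call; assumes the server from the sketch above is running.
import requests

models = requests.get("http://localhost:8000/v1/models").json()
print(models)
# With model_alias set, "id" carries the alias instead of the on-disk path, e.g.:
# {"object": "list", "data": [{"id": "gpt-3.5-turbo", "object": "model",
#                              "owned_by": "me", "permissions": []}]}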