Add model_alias option to override model_path in completions. Closes #39

This commit is contained in:
Andrei Betlen 2023-05-16 17:22:00 -04:00
parent 214589e462
commit a3352923c7
2 changed files with 34 additions and 9 deletions

View file

@@ -522,7 +522,7 @@ class Llama:
if tokens_or_none is not None:
tokens.extend(tokens_or_none)
def create_embedding(self, input: str) -> Embedding:
def create_embedding(self, input: str, model: Optional[str] = None) -> Embedding:
"""Embed a string.
Args:
@@ -532,6 +532,7 @@ class Llama:
An embedding object.
"""
assert self.ctx is not None
_model: str = model if model is not None else self.model_path
if self.params.embedding == False:
raise RuntimeError(
@@ -561,7 +562,7 @@ class Llama:
"index": 0,
}
],
"model": self.model_path,
"model": _model,
"usage": {
"prompt_tokens": n_tokens,
"total_tokens": n_tokens,
@@ -598,6 +599,7 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -610,6 +612,7 @@ class Llama:
text: bytes = b""
returned_characters: int = 0
stop = stop if stop is not None else []
_model: str = model if model is not None else self.model_path
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
@@ -708,7 +711,7 @@ class Llama:
"id": completion_id,
"object": "text_completion",
"created": created,
"model": self.model_path,
"model": _model,
"choices": [
{
"text": text[start:].decode("utf-8", errors="ignore"),
@@ -737,7 +740,7 @@ class Llama:
"id": completion_id,
"object": "text_completion",
"created": created,
"model": self.model_path,
"model": _model,
"choices": [
{
"text": text[returned_characters:].decode(
@@ -807,7 +810,7 @@ class Llama:
"id": completion_id,
"object": "text_completion",
"created": created,
"model": self.model_path,
"model": _model,
"choices": [
{
"text": text_str,
@@ -842,6 +845,7 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
@@ -883,6 +887,7 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -909,6 +914,7 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
"""Generate text from a prompt.
@@ -950,6 +956,7 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
)
def _convert_text_completion_to_chat(
@@ -1026,6 +1033,7 @@ class Llama:
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages.
@@ -1064,6 +1072,7 @@ class Llama:
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore

View file

@@ -16,6 +16,10 @@ class Settings(BaseSettings):
model: str = Field(
description="The path to the model to use for generating completions."
)
model_alias: Optional[str] = Field(
default=None,
description="The alias of the model to use for generating completions.",
)
n_ctx: int = Field(default=2048, ge=1, description="The context size.")
n_gpu_layers: int = Field(
default=0,
@@ -64,6 +68,7 @@ class Settings(BaseSettings):
router = APIRouter()
settings: Optional[Settings] = None
llama: Optional[llama_cpp.Llama] = None
@@ -101,6 +106,12 @@ def create_app(settings: Optional[Settings] = None):
if settings.cache:
cache = llama_cpp.LlamaCache(capacity_bytes=settings.cache_size)
llama.set_cache(cache)
def set_settings(_settings: Settings):
global settings
settings = _settings
set_settings(settings)
return app
@@ -112,6 +123,10 @@ def get_llama():
yield llama
def get_settings():
yield settings
model_field = Field(description="The model to use for generating completions.")
max_tokens_field = Field(
@@ -236,7 +251,6 @@ def create_completion(
completion_or_chunks = llama(
**request.dict(
exclude={
"model",
"n",
"best_of",
"logit_bias",
@@ -274,7 +288,7 @@ CreateEmbeddingResponse = create_model_from_typeddict(llama_cpp.Embedding)
def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
return llama.create_embedding(**request.dict(exclude={"model", "user"}))
return llama.create_embedding(**request.dict(exclude={"user"}))
class ChatCompletionRequestMessage(BaseModel):
@@ -335,7 +349,6 @@ def create_chat_completion(
completion_or_chunks = llama.create_chat_completion(
**request.dict(
exclude={
"model",
"n",
"logit_bias",
"user",
@@ -378,13 +391,16 @@ GetModelResponse = create_model_from_typeddict(ModelList)
@router.get("/v1/models", response_model=GetModelResponse)
def get_models(
settings: Settings = Depends(get_settings),
llama: llama_cpp.Llama = Depends(get_llama),
) -> ModelList:
return {
"object": "list",
"data": [
{
"id": llama.model_path,
"id": settings.model_alias
if settings.model_alias is not None
else llama.model_path,
"object": "model",
"owned_by": "me",
"permissions": [],