bugfix: pydantic v2 fields

This commit is contained in:
Andrei Betlen 2023-07-13 23:25:12 -04:00
parent 896ab7b88a
commit de4cc5a233

View file

@ -31,9 +31,7 @@ class Settings(BaseSettings):
ge=0,
description="The number of layers to put on the GPU. The rest will be on the CPU.",
)
seed: int = Field(
default=1337, description="Random seed. -1 for random."
)
seed: int = Field(default=1337, description="Random seed. -1 for random.")
n_batch: int = Field(
default=512, ge=1, description="The batch size to use per eval."
)
@ -80,12 +78,8 @@ class Settings(BaseSettings):
verbose: bool = Field(
default=True, description="Whether to print debug information."
)
host: str = Field(
default="localhost", description="Listen address"
)
port: int = Field(
default=8000, description="Listen port"
)
host: str = Field(default="localhost", description="Listen address")
port: int = Field(default=8000, description="Listen port")
interrupt_requests: bool = Field(
default=True,
description="Whether to interrupt requests when a new request is received.",
@ -178,7 +172,7 @@ def get_settings():
yield settings
model_field = Field(description="The model to use for generating completions.")
model_field = Field(description="The model to use for generating completions.", default=None)
max_tokens_field = Field(
default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
@ -242,21 +236,18 @@ mirostat_mode_field = Field(
default=0,
ge=0,
le=2,
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)"
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
)
mirostat_tau_field = Field(
default=5.0,
ge=0.0,
le=10.0,
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text"
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
)
mirostat_eta_field = Field(
default=0.1,
ge=0.001,
le=1.0,
description="Mirostat learning rate"
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
)
@ -294,22 +285,23 @@ class CreateCompletionRequest(BaseModel):
model: Optional[str] = model_field
n: Optional[int] = 1
best_of: Optional[int] = 1
user: Optional[str] = Field(None)
user: Optional[str] = Field(default=None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
class Config:
schema_extra = {
"example": {
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
model_config = {
"json_schema_extra": {
"examples": [
{
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
]
}
}
def make_logit_bias_processor(
@ -328,7 +320,7 @@ def make_logit_bias_processor(
elif logit_bias_type == "tokens":
for token, score in logit_bias.items():
token = token.encode('utf-8')
token = token.encode("utf-8")
for input_id in llama.tokenize(token, add_bos=False):
to_bias[input_id] = score
@ -352,7 +344,7 @@ async def create_completion(
request: Request,
body: CreateCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
):
) -> llama_cpp.Completion:
if isinstance(body.prompt, list):
assert len(body.prompt) <= 1
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
@ -364,7 +356,7 @@ async def create_completion(
"logit_bias_type",
"user",
}
kwargs = body.dict(exclude=exclude)
kwargs = body.model_dump(exclude=exclude)
if body.logit_bias is not None:
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@ -396,7 +388,7 @@ async def create_completion(
return EventSourceResponse(
recv_chan, data_sender_callable=partial(event_publisher, send_chan)
)
) # type: ignore
else:
completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs) # type: ignore
return completion
@ -405,16 +397,17 @@ async def create_completion(
class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field
input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str]
user: Optional[str] = Field(default=None)
class Config:
schema_extra = {
"example": {
"input": "The food was delicious and the waiter...",
}
model_config = {
"json_schema_extra": {
"examples": [
{
"input": "The food was delicious and the waiter...",
}
]
}
}
@router.post(
@ -424,7 +417,7 @@ async def create_embedding(
request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
):
return await run_in_threadpool(
llama.create_embedding, **request.dict(exclude={"user"})
llama.create_embedding, **request.model_dump(exclude={"user"})
)
@ -461,21 +454,22 @@ class CreateChatCompletionRequest(BaseModel):
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
class Config:
schema_extra = {
"example": {
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
),
]
}
model_config = {
"json_schema_extra": {
"examples": [
{
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
).model_dump(),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
).model_dump(),
]
}
]
}
}
@router.post(
@ -486,14 +480,14 @@ async def create_chat_completion(
body: CreateChatCompletionRequest,
llama: llama_cpp.Llama = Depends(get_llama),
settings: Settings = Depends(get_settings),
) -> Union[llama_cpp.ChatCompletion]: # type: ignore
) -> llama_cpp.ChatCompletion:
exclude = {
"n",
"logit_bias",
"logit_bias_type",
"user",
}
kwargs = body.dict(exclude=exclude)
kwargs = body.model_dump(exclude=exclude)
if body.logit_bias is not None:
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@ -526,7 +520,7 @@ async def create_chat_completion(
return EventSourceResponse(
recv_chan,
data_sender_callable=partial(event_publisher, send_chan),
)
) # type: ignore
else:
completion: llama_cpp.ChatCompletion = await run_in_threadpool(
llama.create_chat_completion, **kwargs # type: ignore
@ -546,8 +540,6 @@ class ModelList(TypedDict):
data: List[ModelData]
@router.get("/v1/models")
async def get_models(
settings: Settings = Depends(get_settings),