bugfix: pydantic v2 fields

Andrei Betlen 2023-07-13 23:25:12 -04:00
parent 896ab7b88a
commit de4cc5a233

@@ -31,9 +31,7 @@ class Settings(BaseSettings):
         ge=0,
         description="The number of layers to put on the GPU. The rest will be on the CPU.",
     )
-    seed: int = Field(
-        default=1337, description="Random seed. -1 for random."
-    )
+    seed: int = Field(default=1337, description="Random seed. -1 for random.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
@@ -80,12 +78,8 @@ class Settings(BaseSettings):
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
     )
-    host: str = Field(
-        default="localhost", description="Listen address"
-    )
-    port: int = Field(
-        default=8000, description="Listen port"
-    )
+    host: str = Field(default="localhost", description="Listen address")
+    port: int = Field(default=8000, description="Listen port")
     interrupt_requests: bool = Field(
         default=True,
         description="Whether to interrupt requests when a new request is received.",
@@ -178,7 +172,7 @@ def get_settings():
     yield settings


-model_field = Field(description="The model to use for generating completions.")
+model_field = Field(description="The model to use for generating completions.", default=None)

 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
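
The model_field change above is the substantive fix in this hunk: under pydantic v2 an Optional[...] annotation no longer implies a None default, so a Field() without default= makes the field required. A small hedged sketch of the difference, with illustrative model names:

from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class WithoutDefault(BaseModel):
    model: Optional[str] = Field(description="The model to use for generating completions.")


class WithDefault(BaseModel):
    model: Optional[str] = Field(default=None, description="The model to use for generating completions.")


try:
    WithoutDefault()  # pydantic v2 treats the field as required
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # -> missing

print(WithDefault().model)  # -> None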
@@ -242,21 +236,18 @@ mirostat_mode_field = Field(
     default=0,
     ge=0,
     le=2,
-    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)"
+    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
 )

 mirostat_tau_field = Field(
     default=5.0,
     ge=0.0,
     le=10.0,
-    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text"
+    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
 )

 mirostat_eta_field = Field(
-    default=0.1,
-    ge=0.001,
-    le=1.0,
-    description="Mirostat learning rate"
+    default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
 )
@@ -294,22 +285,23 @@ class CreateCompletionRequest(BaseModel):
     model: Optional[str] = model_field
     n: Optional[int] = 1
     best_of: Optional[int] = 1
-    user: Optional[str] = Field(None)
+    user: Optional[str] = Field(default=None)

     # llama.cpp specific parameters
     top_k: int = top_k_field
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

-    class Config:
-        schema_extra = {
-            "example": {
-                "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
-                "stop": ["\n", "###"],
-            }
-        }
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
+                    "stop": ["\n", "###"],
+                }
+            ]
+        }
+    }


 def make_logit_bias_processor(
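
The hunk above replaces the pydantic v1 class Config / schema_extra pattern with v2's model_config dictionary, whose json_schema_extra key takes a list of examples instead of a single example mapping. A minimal sketch of the same migration, using an illustrative request model rather than the server's real one:

from pydantic import BaseModel


class ExampleCompletionRequest(BaseModel):
    prompt: str = ""
    stop: list = []

    # pydantic v2: a plain dict (or ConfigDict) replaces the nested `class Config`,
    # and "examples" is a list where v1's schema_extra used a single "example" dict.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "prompt": "What is the capital of France?",
                    "stop": ["\n", "###"],
                }
            ]
        }
    }


# json_schema_extra is merged into the generated schema (and hence into the OpenAPI docs).
print(ExampleCompletionRequest.model_json_schema()["examples"])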
@@ -328,7 +320,7 @@ def make_logit_bias_processor(
     elif logit_bias_type == "tokens":
         for token, score in logit_bias.items():
-            token = token.encode('utf-8')
+            token = token.encode("utf-8")
             for input_id in llama.tokenize(token, add_bos=False):
                 to_bias[input_id] = score
@@ -352,7 +344,7 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
-):
+) -> llama_cpp.Completion:
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
@@ -364,7 +356,7 @@ async def create_completion(
         "logit_bias_type",
         "user",
     }
-    kwargs = body.dict(exclude=exclude)
+    kwargs = body.model_dump(exclude=exclude)

    if body.logit_bias is not None:
        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
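
This hunk, like the embedding and chat-completion hunks further down, renames the v1 serialization call .dict() to v2's .model_dump(); the exclude= semantics are unchanged. A small sketch with an illustrative model:

from typing import Optional

from pydantic import BaseModel, Field


class ExampleRequest(BaseModel):
    prompt: str = ""
    user: Optional[str] = Field(default=None)
    logit_bias_type: Optional[str] = Field(default=None)


body = ExampleRequest(prompt="Hello", user="abc")

# v1 spelling (still present in pydantic v2 but emits a deprecation warning):
#     kwargs = body.dict(exclude={"user", "logit_bias_type"})
# v2 spelling:
kwargs = body.model_dump(exclude={"user", "logit_bias_type"})
print(kwargs)  # -> {'prompt': 'Hello'}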
@@ -396,7 +388,7 @@ async def create_completion(
         return EventSourceResponse(
             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        )
+        )  # type: ignore
     else:
         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
         return completion
@@ -405,16 +397,17 @@ async def create_completion(
 class CreateEmbeddingRequest(BaseModel):
     model: Optional[str] = model_field
     input: Union[str, List[str]] = Field(description="The input to embed.")
-    user: Optional[str]
+    user: Optional[str] = Field(default=None)

-    class Config:
-        schema_extra = {
-            "example": {
-                "input": "The food was delicious and the waiter...",
-            }
-        }
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "input": "The food was delicious and the waiter...",
+                }
+            ]
+        }
+    }


 @router.post(
@@ -424,7 +417,7 @@ async def create_embedding(
     request: CreateEmbeddingRequest, llama: llama_cpp.Llama = Depends(get_llama)
 ):
     return await run_in_threadpool(
-        llama.create_embedding, **request.dict(exclude={"user"})
+        llama.create_embedding, **request.model_dump(exclude={"user"})
     )
@@ -461,23 +454,24 @@ class CreateChatCompletionRequest(BaseModel):
     repeat_penalty: float = repeat_penalty_field
     logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

-    class Config:
-        schema_extra = {
-            "example": {
-                "messages": [
-                    ChatCompletionRequestMessage(
-                        role="system", content="You are a helpful assistant."
-                    ),
-                    ChatCompletionRequestMessage(
-                        role="user", content="What is the capital of France?"
-                    ),
-                ]
-            }
-        }
+    model_config = {
+        "json_schema_extra": {
+            "examples": [
+                {
+                    "messages": [
+                        ChatCompletionRequestMessage(
+                            role="system", content="You are a helpful assistant."
+                        ).model_dump(),
+                        ChatCompletionRequestMessage(
+                            role="user", content="What is the capital of France?"
+                        ).model_dump(),
+                    ]
+                }
+            ]
+        }
+    }


 @router.post(
     "/v1/chat/completions",
 )
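
In the chat-completion hunk above, the example messages are additionally run through .model_dump() because json_schema_extra should hold plain JSON-serializable data rather than pydantic model instances. A hedged sketch of that detail, using illustrative model names:

from typing import List

from pydantic import BaseModel


class ExampleMessage(BaseModel):
    role: str
    content: str


class ExampleChatRequest(BaseModel):
    messages: List[ExampleMessage]

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        ExampleMessage(role="system", content="You are a helpful assistant.").model_dump(),
                        ExampleMessage(role="user", content="What is the capital of France?").model_dump(),
                    ]
                }
            ]
        }
    }


print(ExampleChatRequest.model_json_schema()["examples"][0]["messages"][0])
# -> {'role': 'system', 'content': 'You are a helpful assistant.'}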
@@ -486,14 +480,14 @@ async def create_chat_completion(
     body: CreateChatCompletionRequest,
     llama: llama_cpp.Llama = Depends(get_llama),
     settings: Settings = Depends(get_settings),
-) -> Union[llama_cpp.ChatCompletion]:  # type: ignore
+) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
         "logit_bias",
         "logit_bias_type",
         "user",
     }
-    kwargs = body.dict(exclude=exclude)
+    kwargs = body.model_dump(exclude=exclude)

    if body.logit_bias is not None:
        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
@@ -526,7 +520,7 @@ async def create_chat_completion(
             return EventSourceResponse(
                 recv_chan,
                 data_sender_callable=partial(event_publisher, send_chan),
-            )
+            )  # type: ignore
         else:
             completion: llama_cpp.ChatCompletion = await run_in_threadpool(
                 llama.create_chat_completion, **kwargs  # type: ignore
@@ -546,8 +540,6 @@ class ModelList(TypedDict):
     data: List[ModelData]


 @router.get("/v1/models")
 async def get_models(
     settings: Settings = Depends(get_settings),