Format

commit 0b121a7456, parent b43917c144
2 changed files with 23 additions and 13 deletions
@@ -87,9 +87,11 @@ class ChatCompletion(TypedDict):
     choices: List[ChatCompletionChoice]
     usage: CompletionUsage
 
+
 class ChatCompletionChunkDeltaEmpty(TypedDict):
     pass
 
+
 class ChatCompletionChunkDelta(TypedDict):
     role: NotRequired[Literal["assistant"]]
     content: NotRequired[str]
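For reference, `NotRequired` marks keys that may be missing entirely from a streamed delta, so consumers should read them with `.get()`. A minimal sketch of that consumption pattern (the sample deltas and the `join_stream` helper are illustrative, not part of the commit):

```python
# Minimal sketch: consuming ChatCompletionChunkDelta-style dicts.
# join_stream and the sample deltas are hypothetical, for illustration only.
from typing import List, Literal

from typing_extensions import NotRequired, TypedDict


class ChatCompletionChunkDelta(TypedDict):
    role: NotRequired[Literal["assistant"]]
    content: NotRequired[str]


def join_stream(deltas: List[ChatCompletionChunkDelta]) -> str:
    # NotRequired keys may be absent, so .get() with a default is needed
    return "".join(delta.get("content", "") for delta in deltas)


deltas: List[ChatCompletionChunkDelta] = [
    {"role": "assistant"},
    {"content": "Hello"},
    {"content": ", world"},
]
print(join_stream(deltas))  # -> Hello, world
```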
@@ -38,11 +38,13 @@ class Settings(BaseSettings):
         default=None,
         description="Split layers across multiple GPUs in proportion.",
     )
-    rope_freq_base: float = Field(default=10000, ge=1, description="RoPE base frequency")
-    rope_freq_scale: float = Field(default=1.0, description="RoPE frequency scaling factor")
-    seed: int = Field(
-        default=1337, description="Random seed. -1 for random."
-    )
+    rope_freq_base: float = Field(
+        default=10000, ge=1, description="RoPE base frequency"
+    )
+    rope_freq_scale: float = Field(
+        default=1.0, description="RoPE frequency scaling factor"
+    )
+    seed: int = Field(default=1337, description="Random seed. -1 for random.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
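Since `Settings` is a pydantic `BaseSettings` model, each `Field` doubles as an environment-variable hook and a validator. A minimal sketch of that behavior, assuming pydantic v1 (where `BaseSettings` still lives in `pydantic` itself):

```python
# Minimal sketch: Field defaults, constraints, and env overrides on BaseSettings.
# Assumes pydantic v1; in pydantic v2, BaseSettings moved to pydantic-settings.
import os

from pydantic import BaseSettings, Field, ValidationError


class Settings(BaseSettings):
    rope_freq_base: float = Field(default=10000, ge=1, description="RoPE base frequency")
    seed: int = Field(default=1337, description="Random seed. -1 for random.")
    n_batch: int = Field(default=512, ge=1, description="The batch size to use per eval.")


os.environ["N_BATCH"] = "256"  # BaseSettings matches env vars by field name
settings = Settings()
print(settings.n_batch)  # -> 256
print(settings.rope_freq_base)  # -> 10000.0

try:
    Settings(n_batch=0)  # violates the ge=1 constraint
except ValidationError as err:
    print(err.errors()[0]["loc"])  # -> ('n_batch',)
```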
@@ -186,7 +188,9 @@ def get_settings():
     yield settings
 
 
-model_field = Field(description="The model to use for generating completions.", default=None)
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
 
 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
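`model_field` and `max_tokens_field` are module-level `Field` objects shared across the request models. A sketch of that reuse pattern (the two request classes below are simplified stand-ins for the server's real ones):

```python
# Minimal sketch: one Field definition reused by several pydantic models.
# The request classes are simplified stand-ins, not the server's actual models.
from typing import Optional

from pydantic import BaseModel, Field

model_field = Field(
    description="The model to use for generating completions.", default=None
)

max_tokens_field = Field(
    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
)


class CreateCompletionRequest(BaseModel):
    model: Optional[str] = model_field
    max_tokens: int = max_tokens_field


class CreateChatCompletionRequest(BaseModel):
    model: Optional[str] = model_field
    max_tokens: int = max_tokens_field


print(CreateCompletionRequest(max_tokens=32).max_tokens)  # -> 32
print(CreateChatCompletionRequest().max_tokens)  # -> 16
```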
@@ -373,9 +377,11 @@ async def create_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
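A logits processor here is just a callable that takes the token ids generated so far plus the raw scores and returns adjusted scores; `LogitsProcessorList` applies each processor in order. A rough sketch of a logit-bias processor in that shape (`make_simple_logit_bias_processor` is hypothetical, not the server's actual `make_logit_bias_processor`):

```python
# Minimal sketch: a logit-bias processor in the (input_ids, scores) -> scores shape.
# make_simple_logit_bias_processor is hypothetical; the server's helper differs.
from typing import Callable, Dict, List

import numpy as np

LogitsProcessor = Callable[[List[int], np.ndarray], np.ndarray]


def make_simple_logit_bias_processor(logit_bias: Dict[int, float]) -> LogitsProcessor:
    def processor(input_ids: List[int], scores: np.ndarray) -> np.ndarray:
        biased = scores.copy()
        for token_id, bias in logit_bias.items():
            biased[token_id] += bias  # shift that token's raw logit
        return biased

    return processor


# LogitsProcessorList applies processors in order; a plain loop shows the idea:
processors = [make_simple_logit_bias_processor({42: -100.0})]
scores = np.zeros(128, dtype=np.float32)
for process in processors:
    scores = process([1, 2, 3], scores)
print(scores[42])  # -> -100.0
```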
@@ -402,7 +408,7 @@ async def create_completion(
 
         return EventSourceResponse(
             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        ) # type: ignore
+        )  # type: ignore
     else:
         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
         return completion
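In the streaming branch, an anyio memory object stream decouples the producer from the response: `event_publisher` pushes chunks into `send_chan`, and sse-starlette's `EventSourceResponse` drains `recv_chan` as server-sent events. A minimal self-contained sketch of that pairing (the `/stream` route and its payloads are made up; the real server streams completion chunks):

```python
# Minimal sketch: the send_chan/recv_chan SSE pattern used by the streaming branch.
# The /stream route and its payloads are hypothetical illustrations.
from functools import partial

import anyio
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


@app.get("/stream")
async def stream():
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan):
        # Producer side: push events, then close the channel to end the stream.
        async with inner_send_chan:
            for i in range(3):
                await inner_send_chan.send(dict(data=f"chunk {i}"))

    return EventSourceResponse(
        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
    )
```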
@@ -512,9 +518,11 @@ async def create_chat_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
@@ -542,7 +550,7 @@ async def create_chat_completion(
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(event_publisher, send_chan),
-        ) # type: ignore
+        )  # type: ignore
     else:
         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
             llama.create_chat_completion, **kwargs  # type: ignore