Format
commit 0b121a7456
parent b43917c144
2 changed files with 23 additions and 13 deletions

llama_cpp/llama_types.py

@@ -87,9 +87,11 @@ class ChatCompletion(TypedDict):
     choices: List[ChatCompletionChoice]
     usage: CompletionUsage
 
+
 class ChatCompletionChunkDeltaEmpty(TypedDict):
     pass
 
+
 class ChatCompletionChunkDelta(TypedDict):
     role: NotRequired[Literal["assistant"]]
     content: NotRequired[str]
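
The added blank lines match black's two-blank-line spacing between top-level classes. As a usage note: both keys of ChatCompletionChunkDelta are NotRequired, so code consuming stream deltas should read them defensively. A minimal sketch under that assumption (the delta_text helper is hypothetical, not part of this commit):

from typing import Literal
from typing_extensions import NotRequired, TypedDict

class ChatCompletionChunkDelta(TypedDict):
    role: NotRequired[Literal["assistant"]]
    content: NotRequired[str]

def delta_text(delta: ChatCompletionChunkDelta) -> str:
    # Hypothetical helper: both keys are NotRequired, so .get() with a
    # default avoids a KeyError on deltas that omit "content".
    return delta.get("content", "")

print(delta_text({"role": "assistant", "content": "Hello"}))  # Hello
print(delta_text({}))  # empty string, not an exception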

llama_cpp/server/app.py

@@ -38,11 +38,13 @@ class Settings(BaseSettings):
         default=None,
         description="Split layers across multiple GPUs in proportion.",
     )
-    rope_freq_base: float = Field(default=10000, ge=1, description="RoPE base frequency")
-    rope_freq_scale: float = Field(default=1.0, description="RoPE frequency scaling factor")
-    seed: int = Field(
-        default=1337, description="Random seed. -1 for random."
-    )
+    rope_freq_base: float = Field(
+        default=10000, ge=1, description="RoPE base frequency"
+    )
+    rope_freq_scale: float = Field(
+        default=1.0, description="RoPE frequency scaling factor"
+    )
+    seed: int = Field(default=1337, description="Random seed. -1 for random.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
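
The rewrapped fields keep their validation constraints: ge=1 on rope_freq_base still rejects values below 1 when settings are constructed. A minimal sketch of that behavior, using a plain BaseModel as a stand-in for the BaseSettings class above:

from pydantic import BaseModel, Field, ValidationError

class Settings(BaseModel):  # stand-in; the real class derives from BaseSettings
    rope_freq_base: float = Field(default=10000, ge=1, description="RoPE base frequency")
    rope_freq_scale: float = Field(default=1.0, description="RoPE frequency scaling factor")
    seed: int = Field(default=1337, description="Random seed. -1 for random.")

print(Settings().rope_freq_base)   # 10000.0 (defaults apply)
try:
    Settings(rope_freq_base=0)     # violates ge=1
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('rope_freq_base',)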

@@ -186,7 +188,9 @@ def get_settings():
     yield settings
 
 
-model_field = Field(description="The model to use for generating completions.", default=None)
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
 
 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
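
model_field and max_tokens_field are module-level Field objects meant to be shared across request models; the commit only rewraps their declarations. A sketch of that sharing pattern under stated assumptions (the request classes below are illustrative stubs, not the ones in this file):

from pydantic import BaseModel, Field

max_tokens_field = Field(
    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
)

class CreateCompletionRequest(BaseModel):  # illustrative stub
    max_tokens: int = max_tokens_field

class CreateChatCompletionRequest(BaseModel):  # illustrative stub
    max_tokens: int = max_tokens_field

# Both models reuse one declaration, so the bounds and description stay in sync.
print(CreateCompletionRequest().max_tokens)  # 16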

@@ -373,9 +377,11 @@ async def create_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
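
The same rewrap recurs in create_chat_completion below. For orientation, a logits processor here is a callable that rescores token logits at each sampling step. make_logit_bias_processor's body is outside this diff, so the sketch below is a simplified, hypothetical version operating on plain lists (the real one also receives the llama instance and a logit_bias_type argument):

from typing import Dict, List

def make_logit_bias_processor_sketch(logit_bias: Dict[int, float]):
    # Hypothetical simplification: additive per-token-id bias, list-based.
    def processor(input_ids: List[int], scores: List[float]) -> List[float]:
        out = list(scores)
        for token_id, bias in logit_bias.items():
            if 0 <= token_id < len(out):
                out[token_id] += bias
        return out
    return processor

proc = make_logit_bias_processor_sketch({2: 5.0})
print(proc([], [0.0, 0.0, 0.0]))  # [0.0, 0.0, 5.0]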

@@ -402,7 +408,7 @@ async def create_completion(
 
         return EventSourceResponse(
             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
         )  # type: ignore
     else:
         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
         return completion
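
The streaming branch pairs a bounded anyio channel with EventSourceResponse: event_publisher pushes chunks into send_chan while the response drains recv_chan. A self-contained sketch of that producer/consumer shape, with print standing in for the SSE transport:

import anyio
from functools import partial

async def event_publisher(send_chan, chunks):
    # Closing the send side ends iteration on the receive side.
    async with send_chan:
        for chunk in chunks:
            await send_chan.send(chunk)

async def main():
    # Buffer of 10, matching the diff.
    send_chan, recv_chan = anyio.create_memory_object_stream(10)
    async with anyio.create_task_group() as tg:
        tg.start_soon(partial(event_publisher, send_chan, ["a", "b", "c"]))
        async with recv_chan:
            async for item in recv_chan:
                print(item)  # stands in for emitting an SSE event

anyio.run(main)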

@@ -512,9 +518,11 @@ async def create_chat_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)

@@ -542,7 +550,7 @@ async def create_chat_completion(
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(event_publisher, send_chan),
         )  # type: ignore
     else:
         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
             llama.create_chat_completion, **kwargs  # type: ignore
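
In the non-streaming branch, run_in_threadpool (from starlette.concurrency) keeps the blocking llama call off the event loop. A minimal sketch with a stub in place of the model:

import asyncio
from starlette.concurrency import run_in_threadpool

def blocking_completion(prompt: str) -> dict:
    # Stub for the synchronous llama call; the real one is long-running
    # CPU-bound work, which is why it belongs in a worker thread.
    return {"choices": [{"text": f"echo: {prompt}"}]}

async def main():
    completion = await run_in_threadpool(blocking_completion, "hi")
    print(completion["choices"][0]["text"])  # echo: hi

asyncio.run(main())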