Andrei Betlen 2023-07-19 03:48:27 -04:00
parent b43917c144
commit 0b121a7456
2 changed files with 23 additions and 13 deletions

llama_cpp/llama_types.py

@@ -87,9 +87,11 @@ class ChatCompletion(TypedDict):
     choices: List[ChatCompletionChoice]
     usage: CompletionUsage
 
+
 class ChatCompletionChunkDeltaEmpty(TypedDict):
     pass
 
+
 class ChatCompletionChunkDelta(TypedDict):
     role: NotRequired[Literal["assistant"]]
     content: NotRequired[str]
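Note: these TypedDicts model the delta payload of streamed chat completion
chunks, where every field may be absent. A minimal sketch of how such deltas
typically look in practice (the example values are illustrative, not from this
commit):

    from typing_extensions import Literal, NotRequired, TypedDict

    class ChatCompletionChunkDelta(TypedDict):
        role: NotRequired[Literal["assistant"]]
        content: NotRequired[str]

    # Every field is optional: the first streamed chunk usually carries only
    # the role, later chunks carry content fragments, the final one is empty.
    first: ChatCompletionChunkDelta = {"role": "assistant"}
    later: ChatCompletionChunkDelta = {"content": "Hello"}
    last: ChatCompletionChunkDelta = {}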

llama_cpp/server/app.py

@@ -38,11 +38,13 @@ class Settings(BaseSettings):
         default=None,
         description="Split layers across multiple GPUs in proportion.",
     )
-    rope_freq_base: float = Field(default=10000, ge=1, description="RoPE base frequency")
-    rope_freq_scale: float = Field(default=1.0, description="RoPE frequency scaling factor")
-    seed: int = Field(
-        default=1337, description="Random seed. -1 for random."
+    rope_freq_base: float = Field(
+        default=10000, ge=1, description="RoPE base frequency"
     )
+    rope_freq_scale: float = Field(
+        default=1.0, description="RoPE frequency scaling factor"
+    )
+    seed: int = Field(default=1337, description="Random seed. -1 for random.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
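Note: the wrapped fields keep their validation behavior. A self-contained
sketch of how these pydantic Field constraints act, assuming a pydantic v1
style import (in pydantic v2, BaseSettings moved to the separate
pydantic-settings package):

    from pydantic import BaseSettings, Field, ValidationError

    class Settings(BaseSettings):
        rope_freq_base: float = Field(
            default=10000, ge=1, description="RoPE base frequency"
        )
        seed: int = Field(default=1337, description="Random seed. -1 for random.")

    settings = Settings()  # defaults apply; env vars like ROPE_FREQ_BASE override
    print(settings.rope_freq_base)  # 10000.0

    try:
        Settings(rope_freq_base=0.5)  # violates ge=1
    except ValidationError as exc:
        print(exc)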
@@ -186,7 +188,9 @@ def get_settings():
     yield settings
 
 
-model_field = Field(description="The model to use for generating completions.", default=None)
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
 
 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
@@ -373,9 +377,11 @@ async def create_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
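Note: make_logit_bias_processor builds the callable that
llama_cpp.LogitsProcessorList applies to the raw logits before sampling. A
rough sketch of the shape such a callable takes, assuming llama-cpp-python's
(input_ids, scores) -> scores convention; the bias map is hypothetical:

    import numpy as np
    import numpy.typing as npt

    def example_logit_bias_processor(
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
    ) -> npt.NDArray[np.single]:
        # Add a fixed bias to chosen token ids: large negative values
        # effectively ban a token, large positive values promote it.
        bias = {42: 10.0, 7: -100.0}  # hypothetical token-id -> bias map
        new_scores = scores.copy()
        for token_id, value in bias.items():
            new_scores[token_id] += value
        return new_scores

    logits = np.zeros(100, dtype=np.single)
    biased = example_logit_bias_processor(np.array([], dtype=np.intc), logits)
    print(biased[42], biased[7])  # 10.0 -100.0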
@@ -402,7 +408,7 @@ async def create_completion(
 
         return EventSourceResponse(
             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
-        ) # type: ignore
+        )  # type: ignore
     else:
         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
         return completion
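Note: the streaming branch pairs an in-memory channel with sse_starlette's
EventSourceResponse: a publisher coroutine feeds send_chan while the response
drains recv_chan. A minimal, self-contained sketch of that producer/consumer
pattern (the names here are illustrative, not the server's own):

    import anyio

    async def main() -> None:
        # Buffer up to 10 items, mirroring create_memory_object_stream(10) above.
        send_chan, recv_chan = anyio.create_memory_object_stream(10)

        async def publisher() -> None:
            async with send_chan:
                for i in range(3):
                    await send_chan.send({"data": f"chunk {i}"})

        async with anyio.create_task_group() as tg:
            tg.start_soon(publisher)
            async with recv_chan:
                async for item in recv_chan:
                    print(item)

    anyio.run(main)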
@@ -512,9 +518,11 @@ async def create_chat_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
@@ -542,7 +550,7 @@ async def create_chat_completion(
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(event_publisher, send_chan),
-        ) # type: ignore
+        )  # type: ignore
     else:
         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
             llama.create_chat_completion, **kwargs  # type: ignore
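Note: in the non-streaming branch, run_in_threadpool moves the blocking llama
call onto a worker thread so the event loop stays responsive. A small sketch of
the pattern; the workload here is a stand-in:

    import anyio
    from starlette.concurrency import run_in_threadpool

    def blocking_work(n: int) -> int:
        # Stand-in for a CPU-bound, blocking call such as
        # llama.create_chat_completion.
        return sum(range(n))

    async def main() -> None:
        result = await run_in_threadpool(blocking_work, 1_000_000)
        print(result)

    anyio.run(main)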