commit 0b121a7456
parent b43917c144
Andrei Betlen, 2023-07-19 03:48:27 -04:00
2 changed files with 23 additions and 13 deletions

@@ -87,9 +87,11 @@ class ChatCompletion(TypedDict):
     choices: List[ChatCompletionChoice]
     usage: CompletionUsage
 
+
 class ChatCompletionChunkDeltaEmpty(TypedDict):
     pass
 
+
 class ChatCompletionChunkDelta(TypedDict):
     role: NotRequired[Literal["assistant"]]
     content: NotRequired[str]
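For reference, a minimal sketch of how the NotRequired fields above behave at the call site; it repeats the ChatCompletionChunkDelta shape purely for illustration and assumes Python 3.11+ for typing.NotRequired (older interpreters need typing_extensions).

# A minimal sketch, not the library's type module.
from typing import Literal, NotRequired, TypedDict


class ChatCompletionChunkDelta(TypedDict):
    role: NotRequired[Literal["assistant"]]
    content: NotRequired[str]


# Both keys are optional, so a partial delta still type-checks:
delta: ChatCompletionChunkDelta = {"content": "Hello"}
print(delta)  # {'content': 'Hello'}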

@@ -38,11 +38,13 @@ class Settings(BaseSettings):
         default=None,
         description="Split layers across multiple GPUs in proportion.",
     )
-    rope_freq_base: float = Field(default=10000, ge=1, description="RoPE base frequency")
-    rope_freq_scale: float = Field(default=1.0, description="RoPE frequency scaling factor")
-    seed: int = Field(
-        default=1337, description="Random seed. -1 for random."
-    )
+    rope_freq_base: float = Field(
+        default=10000, ge=1, description="RoPE base frequency"
+    )
+    rope_freq_scale: float = Field(
+        default=1.0, description="RoPE frequency scaling factor"
+    )
+    seed: int = Field(default=1337, description="Random seed. -1 for random.")
     n_batch: int = Field(
         default=512, ge=1, description="The batch size to use per eval."
     )
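A minimal sketch of how Field declarations like the ones above pick up defaults, constraints, and environment overrides; it is not the server's actual Settings class and assumes pydantic v2 with the separate pydantic-settings package.

# A minimal sketch, assuming pydantic v2 + pydantic-settings.
from pydantic import Field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    rope_freq_base: float = Field(
        default=10000, ge=1, description="RoPE base frequency"
    )
    rope_freq_scale: float = Field(
        default=1.0, description="RoPE frequency scaling factor"
    )
    seed: int = Field(default=1337, description="Random seed. -1 for random.")


# Environment variables override the defaults, e.g.
#   ROPE_FREQ_BASE=20000 python settings_demo.py
# yields rope_freq_base=20000.0, while ge=1 rejects values below 1.
print(Settings().model_dump())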
@@ -186,7 +188,9 @@ def get_settings():
     yield settings
 
 
-model_field = Field(description="The model to use for generating completions.", default=None)
+model_field = Field(
+    description="The model to use for generating completions.", default=None
+)
 
 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
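A hedged sketch of why module-level Field objects such as model_field and max_tokens_field are convenient: one definition can be shared by several request models. The CreateCompletionRequest below is illustrative, not the server's real model; only the two Field definitions mirror what is shown above.

# A hedged sketch of the shared-Field pattern.
from typing import Optional

from pydantic import BaseModel, Field

model_field = Field(
    description="The model to use for generating completions.", default=None
)
max_tokens_field = Field(
    default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
)


class CreateCompletionRequest(BaseModel):
    model: Optional[str] = model_field
    max_tokens: int = max_tokens_field


print(CreateCompletionRequest().max_tokens)  # 16
print(CreateCompletionRequest(max_tokens=256).max_tokens)  # 256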
@@ -373,9 +377,11 @@ async def create_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
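A hedged sketch of the idea behind make_logit_bias_processor: add a caller-supplied bias to selected token logits before sampling. The signature below is simplified (the real helper also receives the llama instance and a logit_bias_type), and the exact callable shape expected by llama_cpp.LogitsProcessorList may differ, e.g. numpy arrays rather than plain lists.

# A hedged, simplified sketch of logit biasing.
from typing import Dict, List


def make_logit_bias_processor(logit_bias: Dict[int, float]):
    def processor(input_ids: List[int], scores: List[float]) -> List[float]:
        # Shift the logit of each biased token before sampling.
        biased = list(scores)
        for token_id, bias in logit_bias.items():
            if 0 <= token_id < len(biased):
                biased[token_id] += bias
        return biased

    return processor


# Push token id 5 strongly downward so it is effectively never sampled.
processor = make_logit_bias_processor({5: -100.0})
print(processor([1, 2, 3], [0.0] * 10)[5])  # -100.0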
@@ -402,7 +408,7 @@ async def create_completion(
         return EventSourceResponse(
             recv_chan, data_sender_callable=partial(event_publisher, send_chan)
         )  # type: ignore
     else:
         completion: llama_cpp.Completion = await run_in_threadpool(llama, **kwargs)  # type: ignore
         return completion
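A minimal sketch of the streaming pattern used here: a producer coroutine pushes chunks into an anyio memory object stream and sse_starlette's EventSourceResponse drains it as server-sent events. iterate_chunks() below is a stand-in for the streamed llama(...) completion chunks, not a real API.

# A minimal sketch of the anyio + EventSourceResponse streaming pattern.
from functools import partial

import anyio
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


def iterate_chunks():
    # Placeholder for the streamed llama.cpp completion chunks.
    for i in range(3):
        yield f"chunk {i}"


@app.get("/stream")
async def stream():
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan):
        # Close the channel when done so EventSourceResponse finishes cleanly.
        async with inner_send_chan:
            for chunk in iterate_chunks():
                await inner_send_chan.send(dict(data=chunk))

    return EventSourceResponse(
        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
    )

Served with e.g. uvicorn, GET /stream emits three events; the bounded buffer of 10 items applies backpressure when the client reads slowly.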
@@ -512,9 +518,11 @@ async def create_chat_completion(
     kwargs = body.model_dump(exclude=exclude)
 
     if body.logit_bias is not None:
-        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
-            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-        ])
+        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
+            [
+                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+            ]
+        )
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
@@ -542,7 +550,7 @@ async def create_chat_completion(
         return EventSourceResponse(
             recv_chan,
             data_sender_callable=partial(event_publisher, send_chan),
         )  # type: ignore
     else:
         completion: llama_cpp.ChatCompletion = await run_in_threadpool(
             llama.create_chat_completion, **kwargs  # type: ignore
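A hedged sketch of the non-streaming branch: run_in_threadpool moves the blocking chat-completion call off the event loop so other requests stay responsive. blocking_chat_completion() is a placeholder for llama.create_chat_completion and is not part of llama-cpp-python.

# A hedged sketch of offloading a blocking call with run_in_threadpool.
import time

from fastapi import FastAPI
from starlette.concurrency import run_in_threadpool

app = FastAPI()


def blocking_chat_completion() -> dict:
    time.sleep(0.1)  # simulate CPU-bound llama.cpp inference
    return {"choices": [{"message": {"role": "assistant", "content": "hi"}}]}


@app.post("/v1/chat/completions")
async def create_chat_completion() -> dict:
    # Run the blocking call in a worker thread so the event loop stays free.
    completion = await run_in_threadpool(blocking_chat_completion)
    return completion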