Pass-Through grammar parameter in web server. (#855) Closes #778

This commit is contained in:
Daniel Thuerck 2023-11-01 23:51:12 +01:00 committed by GitHub
parent 25cb710281
commit 5f8f369d1b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -518,6 +518,10 @@ mirostat_eta_field = Field(
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
)
# Shared pydantic Field describing the optional grammar request parameter.
# The grammar is passed as a string; when a request supplies one, the
# handlers parse it with llama_cpp.LlamaGrammar.from_string.
# NOTE(review): the request models visible in this change declare
# `grammar: Optional[str] = None` directly instead of referencing this
# Field — confirm whether they should use it so the description reaches
# the generated API schema.
grammar = Field(
    default=None,
    description="A GBNF grammar (as string) to be used for formatting the model's output.",
)
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field(
@ -533,6 +537,7 @@ class CreateCompletionRequest(BaseModel):
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
echo: bool = Field(
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@ -634,6 +639,9 @@ async def create_completion(
]
)
if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
iterator_or_completion: Union[
llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
] = await run_in_threadpool(llama, **kwargs)
@ -714,6 +722,7 @@ class CreateChatCompletionRequest(BaseModel):
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
stop: Optional[List[str]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field
@ -772,6 +781,9 @@ async def create_chat_completion(
]
)
if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
iterator_or_completion: Union[
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)