diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 3dd0a38..83cde40 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -518,6 +518,10 @@ mirostat_eta_field = Field(
     default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
 )
+grammar_field = Field(
+    default=None,
+    description="A GBNF grammar (as string) to be used for constraining the model's output."
+)
 
 
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]] = Field(
@@ -533,6 +537,7 @@ class CreateCompletionRequest(BaseModel):
     mirostat_mode: int = mirostat_mode_field
     mirostat_tau: float = mirostat_tau_field
     mirostat_eta: float = mirostat_eta_field
+    grammar: Optional[str] = grammar_field
     echo: bool = Field(
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
@@ -634,6 +639,9 @@ async def create_completion(
         ]
     )
 
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
     iterator_or_completion: Union[
         llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
     ] = await run_in_threadpool(llama, **kwargs)
@@ -714,6 +722,7 @@ class CreateChatCompletionRequest(BaseModel):
     mirostat_mode: int = mirostat_mode_field
     mirostat_tau: float = mirostat_tau_field
     mirostat_eta: float = mirostat_eta_field
+    grammar: Optional[str] = grammar_field
     stop: Optional[List[str]] = stop_field
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
@@ -772,6 +781,9 @@ async def create_chat_completion(
         ]
     )
 
+    if body.grammar is not None:
+        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
     iterator_or_completion: Union[
         llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
     ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
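
For reviewers who want to exercise the new field, here is a minimal client-side sketch. It assumes the server was started with `python -m llama_cpp.server` on the default `localhost:8000` and uses the existing `/v1/completions` route; the grammar string itself is llama.cpp GBNF syntax.

```python
import json
import urllib.request

# A tiny GBNF grammar that restricts the model's output to "yes" or "no".
GRAMMAR = 'root ::= "yes" | "no"'

payload = {
    "prompt": "Q: Is the sky blue?\nA: ",
    "max_tokens": 4,
    "grammar": GRAMMAR,  # the new request field introduced by this patch
}

req = urllib.request.Request(
    "http://localhost:8000/v1/completions",  # default host/port; adjust as needed
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    completion = json.load(resp)

print(completion["choices"][0]["text"])  # constrained to "yes" or "no"
```

Since the field is a plain string, the grammar is only parsed server-side by `LlamaGrammar.from_string`, so a malformed grammar surfaces as a request-time error rather than a client-side one.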