parent
25cb710281
commit
5f8f369d1b
1 changed files with 12 additions and 0 deletions
|
@ -518,6 +518,10 @@ mirostat_eta_field = Field(
|
||||||
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
|
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
grammar = Field(
|
||||||
|
default=None,
|
||||||
|
description="A CBNF grammar (as string) to be used for formatting the model's output."
|
||||||
|
)
|
||||||
|
|
||||||
class CreateCompletionRequest(BaseModel):
|
class CreateCompletionRequest(BaseModel):
|
||||||
prompt: Union[str, List[str]] = Field(
|
prompt: Union[str, List[str]] = Field(
|
||||||
|
@ -533,6 +537,7 @@ class CreateCompletionRequest(BaseModel):
|
||||||
mirostat_mode: int = mirostat_mode_field
|
mirostat_mode: int = mirostat_mode_field
|
||||||
mirostat_tau: float = mirostat_tau_field
|
mirostat_tau: float = mirostat_tau_field
|
||||||
mirostat_eta: float = mirostat_eta_field
|
mirostat_eta: float = mirostat_eta_field
|
||||||
|
grammar: Optional[str] = None
|
||||||
echo: bool = Field(
|
echo: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
|
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
|
||||||
|
@ -634,6 +639,9 @@ async def create_completion(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if body.grammar is not None:
|
||||||
|
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
|
||||||
|
|
||||||
iterator_or_completion: Union[
|
iterator_or_completion: Union[
|
||||||
llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
|
llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
|
||||||
] = await run_in_threadpool(llama, **kwargs)
|
] = await run_in_threadpool(llama, **kwargs)
|
||||||
|
@ -714,6 +722,7 @@ class CreateChatCompletionRequest(BaseModel):
|
||||||
mirostat_mode: int = mirostat_mode_field
|
mirostat_mode: int = mirostat_mode_field
|
||||||
mirostat_tau: float = mirostat_tau_field
|
mirostat_tau: float = mirostat_tau_field
|
||||||
mirostat_eta: float = mirostat_eta_field
|
mirostat_eta: float = mirostat_eta_field
|
||||||
|
grammar: Optional[str] = None
|
||||||
stop: Optional[List[str]] = stop_field
|
stop: Optional[List[str]] = stop_field
|
||||||
stream: bool = stream_field
|
stream: bool = stream_field
|
||||||
presence_penalty: Optional[float] = presence_penalty_field
|
presence_penalty: Optional[float] = presence_penalty_field
|
||||||
|
@ -772,6 +781,9 @@ async def create_chat_completion(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if body.grammar is not None:
|
||||||
|
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
|
||||||
|
|
||||||
iterator_or_completion: Union[
|
iterator_or_completion: Union[
|
||||||
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
|
llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
|
||||||
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
|
||||||
|
|
Loading…
Reference in a new issue