From 9b1c9e902c7846a2cf5d88ce65d35a5f9d9c5f3a Mon Sep 17 00:00:00 2001 From: Eric B Date: Mon, 5 Jun 2023 22:37:11 -0400 Subject: [PATCH] Added mirostat support for completions, chat completions API --- llama_cpp/server/app.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index ea9dec4..23382e1 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -191,6 +191,27 @@ frequency_penalty_field = Field( description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", ) +mirostat_mode_field = Field( + default=0, + ge=0, + le=2, + description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)" +) + +mirostat_tau_field = Field( + default=5.0, + ge=0.0, + le=10.0, + description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text" +) + +mirostat_eta_field = Field( + default=0.1, + ge=0.001, + le=1.0, + description="Mirostat learning rate" +) + class CreateCompletionRequest(BaseModel): prompt: Union[str, List[str]] = Field( @@ -203,6 +224,9 @@ class CreateCompletionRequest(BaseModel): max_tokens: int = max_tokens_field temperature: float = temperature_field top_p: float = top_p_field + mirostat_mode: int = mirostat_mode_field + mirostat_tau: float = mirostat_tau_field + mirostat_eta: float = mirostat_eta_field echo: bool = Field( default=False, description="Whether to echo the prompt in the generated text. Useful for chatbots.", @@ -332,6 +356,9 @@ class CreateChatCompletionRequest(BaseModel): max_tokens: int = max_tokens_field temperature: float = temperature_field top_p: float = top_p_field + mirostat_mode: int = mirostat_mode_field + mirostat_tau: float = mirostat_tau_field + mirostat_eta: float = mirostat_eta_field stop: Optional[List[str]] = stop_field stream: bool = stream_field presence_penalty: Optional[float] = presence_penalty_field