Merge pull request #328 from spirilis/mirostat
Added mirostat support for the completions and chat completions API.

commit d508573fb4
1 changed file with 27 additions and 0 deletions
@@ -191,6 +191,27 @@ frequency_penalty_field = Field(
     description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
 )
 
+mirostat_mode_field = Field(
+    default=0,
+    ge=0,
+    le=2,
+    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)"
+)
+
+mirostat_tau_field = Field(
+    default=5.0,
+    ge=0.0,
+    le=10.0,
+    description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text"
+)
+
+mirostat_eta_field = Field(
+    default=0.1,
+    ge=0.001,
+    le=1.0,
+    description="Mirostat learning rate"
+)
+
 class CreateCompletionRequest(BaseModel):
     prompt: Union[str, List[str]] = Field(
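
The three new Field helpers only declare defaults and validation bounds; below is a minimal sketch (not part of the diff) of how the ge/le constraints behave once such a field is attached to a Pydantic model. The Sampling model name is an illustrative stand-in for the request classes further down.

# Minimal sketch: validation behavior of the diff's mirostat_mode_field.
from pydantic import BaseModel, Field, ValidationError

mirostat_mode_field = Field(
    default=0,
    ge=0,
    le=2,
    description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)"
)

class Sampling(BaseModel):  # illustrative stand-in for the request models
    mirostat_mode: int = mirostat_mode_field

print(Sampling().mirostat_mode)                 # 0 -> Mirostat disabled by default
print(Sampling(mirostat_mode=2).mirostat_mode)  # 2 -> Mirostat v2

try:
    Sampling(mirostat_mode=3)                   # rejected: violates le=2
except ValidationError as exc:
    print(exc)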
@@ -203,6 +224,9 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
+    mirostat_mode: int = mirostat_mode_field
+    mirostat_tau: float = mirostat_tau_field
+    mirostat_eta: float = mirostat_eta_field
     echo: bool = Field(
         default=False,
         description="Whether to echo the prompt in the generated text. Useful for chatbots.",
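
With these fields on CreateCompletionRequest, a client can pass the Mirostat parameters directly in the request body. A hedged example against the server's OpenAI-compatible /v1/completions route follows; the localhost:8000 address assumes the server's default port, and the prompt is purely illustrative.

import requests  # illustrative client; any HTTP client works

# Assumes a local llama-cpp-python server on its default port (8000).
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Q: Name the planets in the solar system. A:",
        "max_tokens": 64,
        "mirostat_mode": 2,   # enable Mirostat v2
        "mirostat_tau": 5.0,  # target entropy (perplexity)
        "mirostat_eta": 0.1,  # learning rate
    },
)
print(resp.json()["choices"][0]["text"])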
@@ -332,6 +356,9 @@ class CreateChatCompletionRequest(BaseModel):
     max_tokens: int = max_tokens_field
     temperature: float = temperature_field
     top_p: float = top_p_field
+    mirostat_mode: int = mirostat_mode_field
+    mirostat_tau: float = mirostat_tau_field
+    mirostat_eta: float = mirostat_eta_field
     stop: Optional[List[str]] = stop_field
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
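
The chat completion model gains the same three fields, so the parameters ride alongside the usual chat options. A sketch under the same local-server assumption as above:

import requests

# Assumes a local llama-cpp-python server on its default port (8000).
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Write a haiku about entropy."}],
        "temperature": 0.8,
        "mirostat_mode": 1,   # Mirostat v1
        "mirostat_tau": 4.0,  # lower tau -> more focused, less diverse output
        "mirostat_eta": 0.1,
    },
)
print(resp.json()["choices"][0]["message"]["content"])

Note that in llama.cpp's sampler a non-zero mirostat_mode generally hands token selection to the adaptive Mirostat controller, so truncation settings such as top_p have little effect while it is enabled.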