Add seed parameter support for completion and chat_completion requests. Closes #884
This commit is contained in:
parent
da1b80285a
commit
86aeb9f3a1
2 changed files with 12 additions and 0 deletions
|
@ -1292,6 +1292,7 @@ class Llama:
|
|||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
tfs_z: float = 1.0,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
|
@ -1367,6 +1368,9 @@ class Llama:
|
|||
except KeyError:
|
||||
if self.verbose:
|
||||
print("Llama._create_completion: cache miss", file=sys.stderr)
|
||||
|
||||
if seed is not None:
|
||||
self._ctx.set_rng_seed(seed)
|
||||
|
||||
finish_reason = "length"
|
||||
multibyte_fix = 0
|
||||
|
@ -1750,6 +1754,7 @@ class Llama:
|
|||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
tfs_z: float = 1.0,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
|
@ -1795,6 +1800,7 @@ class Llama:
|
|||
repeat_penalty=repeat_penalty,
|
||||
top_k=top_k,
|
||||
stream=stream,
|
||||
seed=seed,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
|
@ -1825,6 +1831,7 @@ class Llama:
|
|||
repeat_penalty: float = 1.1,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
tfs_z: float = 1.0,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
|
@ -1870,6 +1877,7 @@ class Llama:
|
|||
repeat_penalty=repeat_penalty,
|
||||
top_k=top_k,
|
||||
stream=stream,
|
||||
seed=seed,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
|
@ -1892,6 +1900,7 @@ class Llama:
|
|||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
max_tokens: int = 256,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
|
@ -1936,6 +1945,7 @@ class Llama:
|
|||
top_k=top_k,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
|
|
@ -608,6 +608,7 @@ class CreateCompletionRequest(BaseModel):
|
|||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
logprobs: Optional[int] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
|
@ -790,6 +791,7 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
|
|
Loading…
Reference in a new issue