Add seed parameter support for completion and chat_completion requests. Closes #884

This commit is contained in:
Andrei Betlen 2023-11-07 23:37:28 -05:00
parent da1b80285a
commit 86aeb9f3a1
2 changed files with 12 additions and 0 deletions

View file

@@ -1292,6 +1292,7 @@ class Llama:
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
seed: Optional[int] = None,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
@@ -1367,6 +1368,9 @@ class Llama:
except KeyError:
if self.verbose:
print("Llama._create_completion: cache miss", file=sys.stderr)
if seed is not None:
self._ctx.set_rng_seed(seed)
finish_reason = "length"
multibyte_fix = 0
@@ -1750,6 +1754,7 @@ class Llama:
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
seed: Optional[int] = None,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
@@ -1795,6 +1800,7 @@ class Llama:
repeat_penalty=repeat_penalty,
top_k=top_k,
stream=stream,
seed=seed,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
@@ -1825,6 +1831,7 @@ class Llama:
repeat_penalty: float = 1.1,
top_k: int = 40,
stream: bool = False,
seed: Optional[int] = None,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
@@ -1870,6 +1877,7 @@ class Llama:
repeat_penalty=repeat_penalty,
top_k=top_k,
stream=stream,
seed=seed,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
@@ -1892,6 +1900,7 @@ class Llama:
top_k: int = 40,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
@@ -1936,6 +1945,7 @@ class Llama:
top_k=top_k,
stream=stream,
stop=stop,
seed=seed,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,

View file

@@ -608,6 +608,7 @@ class CreateCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
logprobs: Optional[int] = Field(None)
seed: Optional[int] = Field(None)
# ignored or currently unsupported
model: Optional[str] = model_field
@@ -790,6 +791,7 @@ class CreateChatCompletionRequest(BaseModel):
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
seed: Optional[int] = Field(None)
# ignored or currently unsupported
model: Optional[str] = model_field