Add seed parameter support for completion and chat_completion requests. Closes #884
This commit is contained in:
parent
da1b80285a
commit
86aeb9f3a1
2 changed files with 12 additions and 0 deletions
|
@ -1292,6 +1292,7 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
tfs_z: float = 1.0,
|
tfs_z: float = 1.0,
|
||||||
mirostat_mode: int = 0,
|
mirostat_mode: int = 0,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
|
@ -1367,6 +1368,9 @@ class Llama:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Llama._create_completion: cache miss", file=sys.stderr)
|
print("Llama._create_completion: cache miss", file=sys.stderr)
|
||||||
|
|
||||||
|
if seed is not None:
|
||||||
|
self._ctx.set_rng_seed(seed)
|
||||||
|
|
||||||
finish_reason = "length"
|
finish_reason = "length"
|
||||||
multibyte_fix = 0
|
multibyte_fix = 0
|
||||||
|
@ -1750,6 +1754,7 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
tfs_z: float = 1.0,
|
tfs_z: float = 1.0,
|
||||||
mirostat_mode: int = 0,
|
mirostat_mode: int = 0,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
|
@ -1795,6 +1800,7 @@ class Llama:
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
seed=seed,
|
||||||
tfs_z=tfs_z,
|
tfs_z=tfs_z,
|
||||||
mirostat_mode=mirostat_mode,
|
mirostat_mode=mirostat_mode,
|
||||||
mirostat_tau=mirostat_tau,
|
mirostat_tau=mirostat_tau,
|
||||||
|
@ -1825,6 +1831,7 @@ class Llama:
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
tfs_z: float = 1.0,
|
tfs_z: float = 1.0,
|
||||||
mirostat_mode: int = 0,
|
mirostat_mode: int = 0,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
|
@ -1870,6 +1877,7 @@ class Llama:
|
||||||
repeat_penalty=repeat_penalty,
|
repeat_penalty=repeat_penalty,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
seed=seed,
|
||||||
tfs_z=tfs_z,
|
tfs_z=tfs_z,
|
||||||
mirostat_mode=mirostat_mode,
|
mirostat_mode=mirostat_mode,
|
||||||
mirostat_tau=mirostat_tau,
|
mirostat_tau=mirostat_tau,
|
||||||
|
@ -1892,6 +1900,7 @@ class Llama:
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
|
seed: Optional[int] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: int = 256,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
|
@ -1936,6 +1945,7 @@ class Llama:
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
seed=seed,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
|
|
|
@ -608,6 +608,7 @@ class CreateCompletionRequest(BaseModel):
|
||||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||||
logprobs: Optional[int] = Field(None)
|
logprobs: Optional[int] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
# ignored or currently unsupported
|
# ignored or currently unsupported
|
||||||
model: Optional[str] = model_field
|
model: Optional[str] = model_field
|
||||||
|
@ -790,6 +791,7 @@ class CreateChatCompletionRequest(BaseModel):
|
||||||
presence_penalty: Optional[float] = presence_penalty_field
|
presence_penalty: Optional[float] = presence_penalty_field
|
||||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
# ignored or currently unsupported
|
# ignored or currently unsupported
|
||||||
model: Optional[str] = model_field
|
model: Optional[str] = model_field
|
||||||
|
|
Loading…
Reference in a new issue