From 86aeb9f3a14808575d2bb0076e6acb4a30907e6a Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Tue, 7 Nov 2023 23:37:28 -0500
Subject: [PATCH] Add seed parameter support for completion and chat_completion
 requests. Closes #884

---
 llama_cpp/llama.py      | 10 ++++++++++
 llama_cpp/server/app.py |  2 ++
 2 files changed, 12 insertions(+)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7a2c34f..b1ba2a0 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1292,6 +1292,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
@@ -1367,6 +1368,9 @@ class Llama:
             except KeyError:
                 if self.verbose:
                     print("Llama._create_completion: cache miss", file=sys.stderr)
+
+        if seed is not None:
+            self._ctx.set_rng_seed(seed)
 
         finish_reason = "length"
         multibyte_fix = 0
@@ -1750,6 +1754,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
@@ -1795,6 +1800,7 @@ class Llama:
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
@@ -1825,6 +1831,7 @@ class Llama:
         repeat_penalty: float = 1.1,
         top_k: int = 40,
         stream: bool = False,
+        seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
@@ -1870,6 +1877,7 @@ class Llama:
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
@@ -1892,6 +1900,7 @@ class Llama:
         top_k: int = 40,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -1936,6 +1945,7 @@ class Llama:
             top_k=top_k,
             stream=stream,
             stop=stop,
+            seed=seed,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index 8ebc427..7caee89 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -608,6 +608,7 @@ class CreateCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
     logprobs: Optional[int] = Field(None)
+    seed: Optional[int] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
@@ -790,6 +791,7 @@ class CreateChatCompletionRequest(BaseModel):
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
+    seed: Optional[int] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
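
Usage sketch (a minimal example assuming this patch is applied; the model path is a placeholder, not part of the patch):

    from llama_cpp import Llama

    # Placeholder path; any local GGUF model works here.
    llm = Llama(model_path="./models/model.gguf")

    # With the patch applied, `seed` reseeds the context RNG before sampling,
    # so two identical calls with the same seed should produce the same text.
    out_a = llm.create_completion("Q: Name a color. A:", max_tokens=16, seed=1337)
    out_b = llm.create_completion("Q: Name a color. A:", max_tokens=16, seed=1337)
    assert out_a["choices"][0]["text"] == out_b["choices"][0]["text"]

Against the server, the new field can be sent in the JSON body of the OpenAI-compatible endpoints, e.g. {"prompt": "...", "seed": 1337} to /v1/completions, or the equivalent chat request to /v1/chat/completions.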