Add set_seed to Llama class
This commit is contained in:
parent
ca4cb88351
commit
fd41ed3a90
1 changed files with 18 additions and 5 deletions
|
@ -998,6 +998,15 @@ class Llama:
|
|||
"""
|
||||
self.cache = cache
|
||||
|
||||
def set_seed(self, seed: int):
|
||||
"""Set the random seed.
|
||||
|
||||
Args:
|
||||
seed: The random seed.
|
||||
"""
|
||||
assert self._ctx.ctx is not None
|
||||
llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
|
||||
|
||||
def reset(self):
|
||||
"""Reset the model state."""
|
||||
self.n_tokens = 0
|
||||
|
@ -1318,10 +1327,14 @@ class Llama:
|
|||
completion_tokens: List[int] = []
|
||||
# Add blank space to start of prompt to match OG llama tokenizer
|
||||
prompt_tokens: List[int] = (
|
||||
self.tokenize(prompt.encode("utf-8"), special=True)
|
||||
if prompt != ""
|
||||
else [self.token_bos()]
|
||||
) if isinstance(prompt, str) else prompt
|
||||
(
|
||||
self.tokenize(prompt.encode("utf-8"), special=True)
|
||||
if prompt != ""
|
||||
else [self.token_bos()]
|
||||
)
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
text: bytes = b""
|
||||
returned_tokens: int = 0
|
||||
stop = (
|
||||
|
@ -1374,7 +1387,7 @@ class Llama:
|
|||
except KeyError:
|
||||
if self.verbose:
|
||||
print("Llama._create_completion: cache miss", file=sys.stderr)
|
||||
|
||||
|
||||
if seed is not None:
|
||||
self._ctx.set_rng_seed(seed)
|
||||
|
||||
|
|
Loading…
Reference in a new issue