Add set_seed to Llama class

This commit is contained in:
Andrei Betlen 2023-11-08 11:09:41 -05:00
parent ca4cb88351
commit fd41ed3a90

View file

@@ -998,6 +998,15 @@ class Llama:
"""
self.cache = cache
def set_seed(self, seed: int):
    """Set the seed for the model's random number generator.

    Args:
        seed: The random seed to pass to the underlying llama.cpp context.

    Raises:
        ValueError: If the underlying llama.cpp context has not been
            initialized (or has already been freed).
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let a NULL context pointer reach the C API.
    if self._ctx.ctx is None:
        raise ValueError("Llama context is not initialized")
    llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
def reset(self):
"""Reset the model state."""
self.n_tokens = 0
@@ -1318,10 +1327,14 @@ class Llama:
completion_tokens: List[int] = []
# Add blank space to start of prompt to match OG llama tokenizer
prompt_tokens: List[int] = (
self.tokenize(prompt.encode("utf-8"), special=True)
if prompt != ""
else [self.token_bos()]
) if isinstance(prompt, str) else prompt
(
self.tokenize(prompt.encode("utf-8"), special=True)
if prompt != ""
else [self.token_bos()]
)
if isinstance(prompt, str)
else prompt
)
text: bytes = b""
returned_tokens: int = 0
stop = (
@@ -1374,7 +1387,7 @@ class Llama:
except KeyError:
if self.verbose:
print("Llama._create_completion: cache miss", file=sys.stderr)
if seed is not None:
self._ctx.set_rng_seed(seed)