Add set_seed to Llama class
This commit is contained in:
parent
ca4cb88351
commit
fd41ed3a90
1 changed file with 18 additions and 5 deletions
|
@ -998,6 +998,15 @@ class Llama:
|
||||||
"""
|
"""
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
|
|
||||||
|
def set_seed(self, seed: int):
    """Seed the model's random number generator.

    Args:
        seed: The random seed.
    """
    ctx = self._ctx.ctx
    assert ctx is not None
    llama_cpp.llama_set_rng_seed(ctx, seed)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the model state."""
|
"""Reset the model state."""
|
||||||
self.n_tokens = 0
|
self.n_tokens = 0
|
||||||
|
@ -1318,10 +1327,14 @@ class Llama:
|
||||||
completion_tokens: List[int] = []
|
completion_tokens: List[int] = []
|
||||||
# Add blank space to start of prompt to match OG llama tokenizer
|
# Add blank space to start of prompt to match OG llama tokenizer
|
||||||
prompt_tokens: List[int] = (
|
prompt_tokens: List[int] = (
|
||||||
|
(
|
||||||
self.tokenize(prompt.encode("utf-8"), special=True)
|
self.tokenize(prompt.encode("utf-8"), special=True)
|
||||||
if prompt != ""
|
if prompt != ""
|
||||||
else [self.token_bos()]
|
else [self.token_bos()]
|
||||||
) if isinstance(prompt, str) else prompt
|
)
|
||||||
|
if isinstance(prompt, str)
|
||||||
|
else prompt
|
||||||
|
)
|
||||||
text: bytes = b""
|
text: bytes = b""
|
||||||
returned_tokens: int = 0
|
returned_tokens: int = 0
|
||||||
stop = (
|
stop = (
|
||||||
|
|
Loading…
Reference in a new issue