diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index dfac9bb..5a0111b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -18,6 +18,7 @@ from typing import ( Iterator, Deque, Callable, + Dict, ) from collections import deque from pathlib import Path @@ -1791,7 +1792,7 @@ class Llama: file=sys.stderr, ) return LlamaState( - scores=self.scores.copy(), + scores=self._scores.copy(), input_ids=self.input_ids.copy(), n_tokens=self.n_tokens, llama_state=bytes(llama_state_compact), @@ -1800,7 +1801,9 @@ class Llama: def load_state(self, state: LlamaState) -> None: assert self._ctx.ctx is not None - self.scores = state.scores.copy() + # Only filling in up to `n_tokens` and then zero-ing out the rest + self.scores[: state.n_tokens, :] = state.scores.copy() + self.scores[state.n_tokens :, :] = 0.0 self.input_ids = state.input_ids.copy() self.n_tokens = state.n_tokens state_size = state.llama_state_size @@ -1951,7 +1954,6 @@ class Llama: local_dir_use_symlinks=local_dir_use_symlinks, cache_dir=cache_dir, local_files_only=True, - ) else: model_path = os.path.join(local_dir, filename)