diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d367601..29f468e 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -141,7 +141,9 @@ class LlamaDiskCache(BaseLlamaCache): if _key is None: raise KeyError("Key not found") value: "LlamaState" = self.cache.pop(_key) # type: ignore - self.cache.push(_key, side="front") # type: ignore + # NOTE: This puts an integer as key in cache, which breaks, + # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens + # self.cache.push(_key, side="front") # type: ignore return value def __contains__(self, key: Sequence[int]) -> bool: @@ -168,7 +170,7 @@ class LlamaState: eval_logits: Deque[List[float]], input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single], - llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] + llama_state: bytes, llama_state_size: int, ): self.eval_tokens = eval_tokens @@ -1509,7 +1511,7 @@ class Llama: eval_logits=self.eval_logits.copy(), scores=self._scores.copy(), input_ids=self._input_ids.copy(), - llama_state=llama_state_compact, + llama_state=bytes(llama_state_compact), llama_state_size=n_bytes, ) @@ -1520,7 +1522,10 @@ class Llama: self._scores = state.scores.copy() self._input_ids = state.input_ids.copy() state_size = state.llama_state_size - if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size: + LLamaStateArrayType = (llama_cpp.c_uint8 * state_size) + llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state) + + if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size: raise RuntimeError("Failed to set llama state data") def n_ctx(self) -> int: