From 10b0cb727b249050e16edc8c14c1526cbea0e500 Mon Sep 17 00:00:00 2001 From: Okabintaro <103938900+Okabintaro@users.noreply.github.com> Date: Tue, 13 Jun 2023 12:03:31 +0200 Subject: [PATCH] fix: Make LlamaState picklable for disk cache I fixed the issue by making the saved state a bytes object instead of the ctypes one, which can't be pickled. --- llama_cpp/llama.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4b6ce8c..0c3d72b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -141,7 +141,9 @@ class LlamaDiskCache(BaseLlamaCache): if _key is None: raise KeyError("Key not found") value: "LlamaState" = self.cache.pop(_key) # type: ignore - self.cache.push(_key, side="front") # type: ignore + # NOTE: This puts an integer as key in cache, which breaks, + # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens + # self.cache.push(_key, side="front") # type: ignore return value def __contains__(self, key: Sequence[int]) -> bool: @@ -168,7 +170,7 @@ class LlamaState: eval_logits: Deque[List[float]], input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single], - llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] + llama_state: bytes, llama_state_size: int, ): self.eval_tokens = eval_tokens @@ -1503,7 +1505,7 @@ class Llama: eval_logits=self.eval_logits.copy(), scores=self._scores.copy(), input_ids=self._input_ids.copy(), - llama_state=llama_state_compact, + llama_state=bytes(llama_state_compact), llama_state_size=n_bytes, ) @@ -1514,7 +1516,10 @@ class Llama: self._scores = state.scores.copy() self._input_ids = state.input_ids.copy() state_size = state.llama_state_size - if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size: + LLamaStateArrayType = (llama_cpp.c_uint8 * state_size) + llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state) + + if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size: 
raise RuntimeError("Failed to set llama state data") def n_ctx(self) -> int: