Merge pull request #370 from Okabintaro/fix-state-pickle

fix: Make LLamaState pickleable for disk cache
This commit is contained in:
Andrei 2023-06-26 08:46:59 -04:00 committed by GitHub
commit 628e3fb3df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -141,7 +141,9 @@ class LlamaDiskCache(BaseLlamaCache):
if _key is None:
raise KeyError("Key not found")
value: "LlamaState" = self.cache.pop(_key) # type: ignore
self.cache.push(_key, side="front") # type: ignore
# NOTE: This puts an integer as key in cache, which breaks,
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
# self.cache.push(_key, side="front") # type: ignore
return value
def __contains__(self, key: Sequence[int]) -> bool:
@ -168,7 +170,7 @@ class LlamaState:
eval_logits: Deque[List[float]],
input_ids: npt.NDArray[np.intc],
scores: npt.NDArray[np.single],
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
llama_state: bytes,
llama_state_size: int,
):
self.eval_tokens = eval_tokens
@ -1509,7 +1511,7 @@ class Llama:
eval_logits=self.eval_logits.copy(),
scores=self._scores.copy(),
input_ids=self._input_ids.copy(),
llama_state=llama_state_compact,
llama_state=bytes(llama_state_compact),
llama_state_size=n_bytes,
)
@ -1520,7 +1522,10 @@ class Llama:
self._scores = state.scores.copy()
self._input_ids = state.input_ids.copy()
state_size = state.llama_state_size
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
raise RuntimeError("Failed to set llama state data")
def n_ctx(self) -> int: