Merge pull request #370 from Okabintaro/fix-state-pickle
fix: Make LLamaState pickleable for disk cache
This commit is contained in:
commit
628e3fb3df
1 changed files with 9 additions and 4 deletions
|
@ -141,7 +141,9 @@ class LlamaDiskCache(BaseLlamaCache):
|
|||
if _key is None:
|
||||
raise KeyError("Key not found")
|
||||
value: "LlamaState" = self.cache.pop(_key) # type: ignore
|
||||
self.cache.push(_key, side="front") # type: ignore
|
||||
# NOTE: This puts an integer as key in cache, which breaks,
|
||||
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
|
||||
# self.cache.push(_key, side="front") # type: ignore
|
||||
return value
|
||||
|
||||
def __contains__(self, key: Sequence[int]) -> bool:
|
||||
|
@ -168,7 +170,7 @@ class LlamaState:
|
|||
eval_logits: Deque[List[float]],
|
||||
input_ids: npt.NDArray[np.intc],
|
||||
scores: npt.NDArray[np.single],
|
||||
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
|
||||
llama_state: bytes,
|
||||
llama_state_size: int,
|
||||
):
|
||||
self.eval_tokens = eval_tokens
|
||||
|
@ -1509,7 +1511,7 @@ class Llama:
|
|||
eval_logits=self.eval_logits.copy(),
|
||||
scores=self._scores.copy(),
|
||||
input_ids=self._input_ids.copy(),
|
||||
llama_state=llama_state_compact,
|
||||
llama_state=bytes(llama_state_compact),
|
||||
llama_state_size=n_bytes,
|
||||
)
|
||||
|
||||
|
@ -1520,7 +1522,10 @@ class Llama:
|
|||
self._scores = state.scores.copy()
|
||||
self._input_ids = state.input_ids.copy()
|
||||
state_size = state.llama_state_size
|
||||
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
|
||||
LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
|
||||
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
|
||||
|
||||
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
|
||||
raise RuntimeError("Failed to set llama state data")
|
||||
|
||||
def n_ctx(self) -> int:
|
||||
|
|
Loading…
Reference in a new issue