Merge pull request #370 from Okabintaro/fix-state-pickle
fix: Make LLamaState pickleable for disk cache
commit 628e3fb3df
1 changed file with 9 additions and 4 deletions
@@ -141,7 +141,9 @@ class LlamaDiskCache(BaseLlamaCache):
         if _key is None:
             raise KeyError("Key not found")
         value: "LlamaState" = self.cache.pop(_key)  # type: ignore
-        self.cache.push(_key, side="front")  # type: ignore
+        # NOTE: This puts an integer as key in cache, which breaks,
+        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
+        # self.cache.push(_key, side="front")  # type: ignore
         return value

     def __contains__(self, key: Sequence[int]) -> bool:
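The push() call that this first hunk comments out was meant to refresh the entry's recency, but diskcache's Cache.push() is a queue primitive: it stores its argument as a value under a newly generated integer key instead of reordering the existing tuple key. A rough, hypothetical sketch of that side effect (directory and values are placeholders, not part of the change):

    import tempfile
    from diskcache import Cache

    # Illustrative only: push() adds a new integer-keyed entry, so later key
    # iteration sees an int where a tuple of token ids is expected, which is
    # what breaks Llama.longest_token_prefix(k, key).
    with Cache(tempfile.mkdtemp()) as cache:
        token_key = (1, 2, 3)
        cache[token_key] = "dummy state"
        cache.push(token_key, side="front")
        print(list(cache))  # roughly: [<generated int key>, (1, 2, 3)]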
@@ -168,7 +170,7 @@ class LlamaState:
         eval_logits: Deque[List[float]],
         input_ids: npt.NDArray[np.intc],
         scores: npt.NDArray[np.single],
-        llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
+        llama_state: bytes,
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
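This hunk retypes llama_state as bytes because LlamaDiskCache persists whole LlamaState objects through diskcache, which pickles them; a plain bytes payload survives that round trip without any ctypes machinery. A minimal sketch with a stand-in class (FakeState is hypothetical, not part of the change):

    import pickle

    class FakeState:
        # Stand-in for LlamaState; only the pickling behaviour matters here.
        def __init__(self, llama_state: bytes, llama_state_size: int):
            self.llama_state = llama_state
            self.llama_state_size = llama_state_size

    state = FakeState(llama_state=b"\x00\x01\x02\x03", llama_state_size=4)
    restored = pickle.loads(pickle.dumps(state))
    assert restored.llama_state == state.llama_state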
@@ -1509,7 +1511,7 @@ class Llama:
             eval_logits=self.eval_logits.copy(),
             scores=self._scores.copy(),
             input_ids=self._input_ids.copy(),
-            llama_state=llama_state_compact,
+            llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
         )

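On the save side, wrapping the compact ctypes buffer in bytes() copies its contents into an immutable, picklable object before it is stored on the LlamaState. The conversion in isolation (the array below is only a stand-in for llama_state_compact):

    import ctypes

    compact = (ctypes.c_uint8 * 4)(1, 2, 3, 4)  # stand-in for llama_state_compact
    as_bytes = bytes(compact)                   # copies the buffer into bytes
    assert as_bytes == b"\x01\x02\x03\x04"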
@@ -1520,7 +1522,10 @@ class Llama:
         self._scores = state.scores.copy()
         self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+        LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
+        llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
+
+        if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")

     def n_ctx(self) -> int:
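load_state then has to hand a C buffer back to llama_set_state_data, so the stored bytes are copied into a freshly allocated ctypes array with from_buffer_copy, which is what the last hunk adds. The same conversion in isolation (size and contents are illustrative):

    import ctypes

    state_bytes = b"\x01\x02\x03\x04"    # stand-in for state.llama_state
    state_size = len(state_bytes)        # stand-in for state.llama_state_size
    LLamaStateArrayType = ctypes.c_uint8 * state_size
    llama_state = LLamaStateArrayType.from_buffer_copy(state_bytes)
    assert bytes(llama_state) == state_bytes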