diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d201013..1b9f9e9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -53,12 +53,14 @@ class LlamaState: def __init__( self, eval_tokens: Deque[llama_cpp.llama_token], - eval_logits: Deque[List[float]], + eval_logits: Deque[List[llama_cpp.c_float]], llama_state, + llama_state_size: llama_cpp.c_size_t, ): self.eval_tokens = eval_tokens self.eval_logits = eval_logits self.llama_state = llama_state + self.llama_state_size = llama_state_size class Llama: @@ -950,19 +952,23 @@ class Llama: assert self.ctx is not None state_size = llama_cpp.llama_get_state_size(self.ctx) llama_state = (llama_cpp.c_uint8 * int(state_size))() - if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size: + n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state) + if int(n_bytes) > int(state_size): raise RuntimeError("Failed to copy llama state data") + llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))() + llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes)) return LlamaState( eval_tokens=self.eval_tokens.copy(), eval_logits=self.eval_logits.copy(), - llama_state=llama_state, + llama_state=llama_state_compact, + llama_state_size=n_bytes, ) def load_state(self, state: LlamaState) -> None: assert self.ctx is not None self.eval_tokens = state.eval_tokens.copy() self.eval_logits = state.eval_logits.copy() - state_size = llama_cpp.llama_get_state_size(self.ctx) + state_size = state.llama_state_size if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size: raise RuntimeError("Failed to set llama state data")