Support smaller state sizes
This commit is contained in:
parent
1d47cce222
commit
43f2907e3a
1 changed files with 10 additions and 4 deletions
|
@ -53,12 +53,14 @@ class LlamaState:
|
|||
def __init__(
|
||||
self,
|
||||
eval_tokens: Deque[llama_cpp.llama_token],
|
||||
eval_logits: Deque[List[float]],
|
||||
eval_logits: Deque[List[llama_cpp.c_float]],
|
||||
llama_state,
|
||||
llama_state_size: llama_cpp.c_size_t,
|
||||
):
|
||||
self.eval_tokens = eval_tokens
|
||||
self.eval_logits = eval_logits
|
||||
self.llama_state = llama_state
|
||||
self.llama_state_size = llama_state_size
|
||||
|
||||
|
||||
class Llama:
|
||||
|
@ -950,19 +952,23 @@ class Llama:
|
|||
assert self.ctx is not None
|
||||
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
||||
llama_state = (llama_cpp.c_uint8 * int(state_size))()
|
||||
if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
|
||||
n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
|
||||
if int(n_bytes) > int(state_size):
|
||||
raise RuntimeError("Failed to copy llama state data")
|
||||
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
|
||||
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
|
||||
return LlamaState(
|
||||
eval_tokens=self.eval_tokens.copy(),
|
||||
eval_logits=self.eval_logits.copy(),
|
||||
llama_state=llama_state,
|
||||
llama_state=llama_state_compact,
|
||||
llama_state_size=n_bytes,
|
||||
)
|
||||
|
||||
def load_state(self, state: LlamaState) -> None:
|
||||
assert self.ctx is not None
|
||||
self.eval_tokens = state.eval_tokens.copy()
|
||||
self.eval_logits = state.eval_logits.copy()
|
||||
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
||||
state_size = state.llama_state_size
|
||||
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
|
||||
raise RuntimeError("Failed to set llama state data")
|
||||
|
||||
|
|
Loading…
Reference in a new issue