Support smaller state sizes

This commit is contained in:
Andrei Betlen 2023-05-03 09:33:50 -04:00
parent 1d47cce222
commit 43f2907e3a

View file

@ -53,12 +53,14 @@ class LlamaState:
def __init__( def __init__(
self, self,
eval_tokens: Deque[llama_cpp.llama_token], eval_tokens: Deque[llama_cpp.llama_token],
eval_logits: Deque[List[float]], eval_logits: Deque[List[llama_cpp.c_float]],
llama_state, llama_state,
llama_state_size: llama_cpp.c_size_t,
): ):
self.eval_tokens = eval_tokens self.eval_tokens = eval_tokens
self.eval_logits = eval_logits self.eval_logits = eval_logits
self.llama_state = llama_state self.llama_state = llama_state
self.llama_state_size = llama_state_size
class Llama: class Llama:
@ -950,19 +952,23 @@ class Llama:
assert self.ctx is not None assert self.ctx is not None
state_size = llama_cpp.llama_get_state_size(self.ctx) state_size = llama_cpp.llama_get_state_size(self.ctx)
llama_state = (llama_cpp.c_uint8 * int(state_size))() llama_state = (llama_cpp.c_uint8 * int(state_size))()
if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size: n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
if int(n_bytes) > int(state_size):
raise RuntimeError("Failed to copy llama state data") raise RuntimeError("Failed to copy llama state data")
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
return LlamaState( return LlamaState(
eval_tokens=self.eval_tokens.copy(), eval_tokens=self.eval_tokens.copy(),
eval_logits=self.eval_logits.copy(), eval_logits=self.eval_logits.copy(),
llama_state=llama_state, llama_state=llama_state_compact,
llama_state_size=n_bytes,
) )
def load_state(self, state: LlamaState) -> None: def load_state(self, state: LlamaState) -> None:
assert self.ctx is not None assert self.ctx is not None
self.eval_tokens = state.eval_tokens.copy() self.eval_tokens = state.eval_tokens.copy()
self.eval_logits = state.eval_logits.copy() self.eval_logits = state.eval_logits.copy()
state_size = llama_cpp.llama_get_state_size(self.ctx) state_size = state.llama_state_size
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size: if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
raise RuntimeError("Failed to set llama state data") raise RuntimeError("Failed to set llama state data")