Support smaller state sizes
This commit is contained in:
parent
1d47cce222
commit
43f2907e3a
1 changed files with 10 additions and 4 deletions
|
@ -53,12 +53,14 @@ class LlamaState:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
eval_tokens: Deque[llama_cpp.llama_token],
|
eval_tokens: Deque[llama_cpp.llama_token],
|
||||||
eval_logits: Deque[List[float]],
|
eval_logits: Deque[List[llama_cpp.c_float]],
|
||||||
llama_state,
|
llama_state,
|
||||||
|
llama_state_size: llama_cpp.c_size_t,
|
||||||
):
|
):
|
||||||
self.eval_tokens = eval_tokens
|
self.eval_tokens = eval_tokens
|
||||||
self.eval_logits = eval_logits
|
self.eval_logits = eval_logits
|
||||||
self.llama_state = llama_state
|
self.llama_state = llama_state
|
||||||
|
self.llama_state_size = llama_state_size
|
||||||
|
|
||||||
|
|
||||||
class Llama:
|
class Llama:
|
||||||
|
@ -950,19 +952,23 @@ class Llama:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
||||||
llama_state = (llama_cpp.c_uint8 * int(state_size))()
|
llama_state = (llama_cpp.c_uint8 * int(state_size))()
|
||||||
if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
|
n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
|
||||||
|
if int(n_bytes) > int(state_size):
|
||||||
raise RuntimeError("Failed to copy llama state data")
|
raise RuntimeError("Failed to copy llama state data")
|
||||||
|
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
|
||||||
|
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
|
||||||
return LlamaState(
|
return LlamaState(
|
||||||
eval_tokens=self.eval_tokens.copy(),
|
eval_tokens=self.eval_tokens.copy(),
|
||||||
eval_logits=self.eval_logits.copy(),
|
eval_logits=self.eval_logits.copy(),
|
||||||
llama_state=llama_state,
|
llama_state=llama_state_compact,
|
||||||
|
llama_state_size=n_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_state(self, state: LlamaState) -> None:
|
def load_state(self, state: LlamaState) -> None:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
self.eval_tokens = state.eval_tokens.copy()
|
self.eval_tokens = state.eval_tokens.copy()
|
||||||
self.eval_logits = state.eval_logits.copy()
|
self.eval_logits = state.eval_logits.copy()
|
||||||
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
state_size = state.llama_state_size
|
||||||
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
|
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
|
||||||
raise RuntimeError("Failed to set llama state data")
|
raise RuntimeError("Failed to set llama state data")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue