diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index f7a6e9e..c857bbe 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
 from collections import deque
 
 from . import llama_cpp
@@ -20,6 +20,18 @@ class LlamaCache:
     pass
 
 
+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
@@ -85,8 +97,8 @@ class Llama:
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ### saving and restoring state, this allows us to continue a completion if the last
@@ -204,7 +216,10 @@ class Llama:
             cols = int(n_vocab)
             rows = n_tokens
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+            logits = [
+                [logits_view[i * cols + j] for j in range(cols)]
+                for i in range(rows)
+            ]
             self.eval_logits.extend(logits)
 
     def sample(
@@ -828,6 +843,26 @@ class Llama:
             verbose=state["verbose"],
         )
 
+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""
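
Usage sketch (not part of the patch): how the new save_state()/load_state() methods might be exercised to snapshot and rewind a context. The model path is a placeholder, and the tokenize()/eval() calls assume the signatures this version of the wrapper already exposes.

    from llama_cpp import Llama

    llama = Llama(model_path="./models/ggml-model.bin")  # placeholder path

    # Evaluate a shared prefix once, then snapshot the llama.cpp state
    # together with copies of the eval_tokens/eval_logits deques.
    llama.eval(llama.tokenize(b"The quick brown fox"))
    state = llama.save_state()

    # Continue past the prefix...
    llama.eval(llama.tokenize(b" jumps over the lazy dog"))

    # ...then rewind to the snapshot so the evaluated prefix can be reused
    # without re-running eval() on it.
    llama.load_state(state)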