Add save/load state API for Llama class

Andrei Betlen 2023-04-24 17:51:25 -04:00
parent c4c332fc51
commit 197cf80601


@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
 from collections import deque

 from . import llama_cpp
@@ -20,6 +20,18 @@ class LlamaCache:
     pass


+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -85,8 +97,8 @@ class Llama:
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)

-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)

         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ### saving and restoring state, this allows us to continue a completion if the last
@@ -204,7 +216,10 @@ class Llama:
             cols = int(n_vocab)
             rows = n_tokens
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+            logits = [
+                [logits_view[i * cols + j] for j in range(cols)]
+                for i in range(rows)
+            ]
             self.eval_logits.extend(logits)

     def sample(
@@ -828,6 +843,26 @@ class Llama:
             verbose=state["verbose"],
         )

+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""