Add save/load state api for Llama class

commit 197cf80601
parent c4c332fc51
Author: Andrei Betlen
Date:   2023-04-24 17:51:25 -04:00


@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
 from collections import deque
 from . import llama_cpp
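
The switch to typing.Deque here (and in the annotations below) is presumably for compatibility with Python versions before 3.9, where collections.deque itself is not subscriptable; annotations on attribute targets such as self.x are evaluated at runtime, so deque[...] would raise TypeError there. A minimal sketch of the distinction (the int element type is illustrative):

from collections import deque
from typing import Deque

# typing.Deque[int] is subscriptable on older Python 3 versions, whereas
# deque[int] itself only works on Python >= 3.9. Since annotations on
# attribute targets (self.x: ...) are evaluated at runtime, this matters.
eval_tokens: Deque[int] = deque(maxlen=8)
eval_tokens.extend([1, 2, 3])
print(list(eval_tokens))  # [1, 2, 3]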
@@ -20,6 +20,18 @@ class LlamaCache:
     pass
 
+
+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
 
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -85,8 +97,8 @@ class Llama:
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ###       saving and restoring state, this allows us to continue a completion if the last
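
Both caches are created with maxlen=n_ctx, so once the context window fills, each append silently evicts the oldest entry. A small self-contained sketch of that bounded-deque behavior (the tiny maxlen is illustrative):

from collections import deque

# A deque with maxlen acts as a fixed-size ring buffer: when full, each
# append drops the oldest element from the opposite end.
cache = deque(maxlen=4)
for token in range(6):
    cache.append(token)

print(list(cache))  # [2, 3, 4, 5] -- tokens 0 and 1 were evicted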
@@ -204,7 +216,10 @@ class Llama:
         cols = int(n_vocab)
         rows = n_tokens
         logits_view = llama_cpp.llama_get_logits(self.ctx)
-        logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+        logits = [
+            [logits_view[i * cols + j] for j in range(cols)]
+            for i in range(rows)
+        ]
         self.eval_logits.extend(logits)
 
     def sample(
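
llama_get_logits exposes one flat, row-major buffer of rows * cols floats, so the logit for token i and vocab entry j lives at index i * cols + j; the reformatted comprehension above just splits that buffer into per-token rows. The same reshaping over a plain Python list (the 2x3 shape and values are made up for illustration):

# Flat row-major buffer: rows=2 tokens, cols=3 vocab entries.
rows, cols = 2, 3
flat = [0.1, 0.2, 0.3, 1.1, 1.2, 1.3]

# Identical indexing to the diff: element (i, j) sits at flat[i * cols + j].
logits = [[flat[i * cols + j] for j in range(cols)] for i in range(rows)]
assert logits == [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3]]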
@@ -828,6 +843,26 @@ class Llama:
             verbose=state["verbose"],
         )
 
+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""