Add save/load state api for Llama class
Parent: c4c332fc51
Commit: 197cf80601
1 changed file with 39 additions and 4 deletions
@@ -4,7 +4,7 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
 from collections import deque
 
 from . import llama_cpp
@@ -20,6 +20,18 @@ class LlamaCache:
     pass
 
 
+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
@@ -85,8 +97,8 @@ class Llama:
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ###       saving and restoring state, this allows us to continue a completion if the last
@@ -204,7 +216,10 @@ class Llama:
             cols = int(n_vocab)
             rows = n_tokens
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+            logits = [
+                [logits_view[i * cols + j] for j in range(cols)]
+                for i in range(rows)
+            ]
             self.eval_logits.extend(logits)
 
     def sample(
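The reformatted comprehension above copies llama.cpp's flat, row-major logits buffer into one Python list per evaluated token. A self-contained sketch of the same indexing with made-up sizes; in the wrapper, rows, cols and the values come from the batch size, llama_n_vocab and llama_get_logits:

    # Stand-in for the flat ctypes buffer returned by llama_get_logits.
    rows, cols = 2, 4
    flat = [float(i) for i in range(rows * cols)]

    logits = [
        [flat[i * cols + j] for j in range(cols)]  # row i holds the logits for token i
        for i in range(rows)
    ]
    assert logits == [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]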
@@ -828,6 +843,26 @@ class Llama:
             verbose=state["verbose"],
         )
 
+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""