Add save/load state API for Llama class
This commit is contained in:
parent
c4c332fc51
commit
197cf80601
1 changed files with 39 additions and 4 deletions
|
@ -4,7 +4,7 @@ import uuid
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import List, Optional, Union, Generator, Sequence, Iterator
|
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
from . import llama_cpp
|
from . import llama_cpp
|
||||||
|
@ -20,6 +20,18 @@ class LlamaCache:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaState:
    """Snapshot of a model's evaluation state.

    Bundles the evaluated-token history, the per-token logits, and the raw
    llama.cpp state data (in this file, a ctypes byte array filled by
    ``llama_copy_state_data``).

    The arguments are stored exactly as given; no copies are made here.
    """

    def __init__(
        self,
        eval_tokens: Deque[llama_cpp.llama_token],
        eval_logits: Deque[List[float]],
        llama_state,
    ):
        # Only references are kept -- the caller decides whether to copy first.
        self.llama_state = llama_state
        self.eval_tokens, self.eval_logits = eval_tokens, eval_logits
|
||||||
|
|
||||||
|
|
||||||
class Llama:
|
class Llama:
|
||||||
"""High-level Python wrapper for a llama.cpp model."""
|
"""High-level Python wrapper for a llama.cpp model."""
|
||||||
|
|
||||||
|
@ -85,8 +97,8 @@ class Llama:
|
||||||
|
|
||||||
self.last_n_tokens_size = last_n_tokens_size
|
self.last_n_tokens_size = last_n_tokens_size
|
||||||
self.n_batch = min(n_ctx, n_batch)
|
self.n_batch = min(n_ctx, n_batch)
|
||||||
self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
|
||||||
self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
|
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
|
||||||
|
|
||||||
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
|
### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
|
||||||
### saving and restoring state, this allows us to continue a completion if the last
|
### saving and restoring state, this allows us to continue a completion if the last
|
||||||
|
@ -204,7 +216,10 @@ class Llama:
|
||||||
cols = int(n_vocab)
|
cols = int(n_vocab)
|
||||||
rows = n_tokens
|
rows = n_tokens
|
||||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||||
logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
|
logits = [
|
||||||
|
[logits_view[i * cols + j] for j in range(cols)]
|
||||||
|
for i in range(rows)
|
||||||
|
]
|
||||||
self.eval_logits.extend(logits)
|
self.eval_logits.extend(logits)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
|
@ -828,6 +843,26 @@ class Llama:
|
||||||
verbose=state["verbose"],
|
verbose=state["verbose"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def save_state(self) -> LlamaState:
    """Capture the current context state together with the evaluation history.

    Returns:
        A ``LlamaState`` holding shallow copies of ``eval_tokens`` and
        ``eval_logits`` plus a byte buffer filled by
        ``llama_copy_state_data``.

    Raises:
        RuntimeError: If the value returned by ``llama_copy_state_data``
            differs from the size reported by ``llama_get_state_size``.
    """
    assert self.ctx is not None
    expected_size = llama_cpp.llama_get_state_size(self.ctx)
    # Build a ctypes uint8 array type of the reported size, then instantiate it.
    buffer_type = llama_cpp.c_uint8 * int(expected_size)
    state_buffer = buffer_type()
    copied = llama_cpp.llama_copy_state_data(self.ctx, state_buffer)
    if copied != expected_size:
        raise RuntimeError("Failed to copy llama state data")
    # Copy the deques so later evaluation does not alter the snapshot's history.
    tokens_snapshot = self.eval_tokens.copy()
    logits_snapshot = self.eval_logits.copy()
    return LlamaState(
        eval_tokens=tokens_snapshot,
        eval_logits=logits_snapshot,
        llama_state=state_buffer,
    )
|
||||||
|
|
||||||
|
def load_state(self, state: LlamaState) -> None:
    """Restore a previously saved state into this instance.

    The native context is restored first. The Python-side history
    (``eval_tokens`` / ``eval_logits``) is replaced only after that call
    reports success, so a failed restore does not leave this instance
    describing a state that was never loaded.

    Args:
        state: Snapshot to restore, e.g. one returned by ``save_state``.

    Raises:
        RuntimeError: If the value returned by ``llama_set_state_data``
            differs from the size reported by ``llama_get_state_size``.
    """
    assert self.ctx is not None
    state_size = llama_cpp.llama_get_state_size(self.ctx)
    if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
        raise RuntimeError("Failed to set llama state data")
    # Copy the deques so later evaluation on this instance does not mutate
    # the snapshot, which may be loaded again.
    self.eval_tokens = state.eval_tokens.copy()
    self.eval_logits = state.eval_logits.copy()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def token_eos() -> llama_cpp.llama_token:
|
def token_eos() -> llama_cpp.llama_token:
|
||||||
"""Return the end-of-sequence token."""
|
"""Return the end-of-sequence token."""
|
||||||
|
|
Loading…
Reference in a new issue