Use pre-allocated buffers to store input_ids and scores
This commit is contained in:
parent
a5e059c053
commit
b95b0ffbeb
1 changed files with 44 additions and 42 deletions
|
@ -166,17 +166,15 @@ class LlamaDiskCache(BaseLlamaCache):
|
|||
class LlamaState:
|
||||
def __init__(
|
||||
self,
|
||||
eval_tokens: Deque[int],
|
||||
eval_logits: Deque[List[float]],
|
||||
input_ids: npt.NDArray[np.intc],
|
||||
scores: npt.NDArray[np.single],
|
||||
n_tokens: int,
|
||||
llama_state: bytes,
|
||||
llama_state_size: int,
|
||||
):
|
||||
self.eval_tokens = eval_tokens
|
||||
self.eval_logits = eval_logits
|
||||
self.input_ids = input_ids
|
||||
self.scores = scores
|
||||
self.n_tokens = n_tokens
|
||||
self.llama_state = llama_state
|
||||
self.llama_state_size = llama_state_size
|
||||
|
||||
|
@ -267,8 +265,6 @@ class Llama:
|
|||
|
||||
self.last_n_tokens_size = last_n_tokens_size
|
||||
self.n_batch = min(n_ctx, n_batch)
|
||||
self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
|
||||
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
|
||||
|
||||
self.cache: Optional[BaseLlamaCache] = None
|
||||
|
||||
|
@ -329,8 +325,30 @@ class Llama:
|
|||
self._token_nl = Llama.token_nl()
|
||||
self._token_eos = Llama.token_eos()
|
||||
|
||||
self._input_ids = np.array([], dtype=np.intc)
|
||||
self._scores: npt.NDArray[np.single] = np.ndarray((0, self._n_vocab), dtype=np.single)
|
||||
self.n_tokens = 0
|
||||
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
|
||||
self.scores: npt.NDArray[np.single] = np.ndarray(
|
||||
(n_ctx, self._n_vocab), dtype=np.single
|
||||
)
|
||||
|
||||
@property
|
||||
def _input_ids(self) -> npt.NDArray[np.intc]:
|
||||
return self.input_ids[: self.n_tokens]
|
||||
|
||||
@property
|
||||
def _scores(self) -> npt.NDArray[np.single]:
|
||||
return self.scores[: self.n_tokens, :]
|
||||
|
||||
@property
|
||||
def eval_tokens(self) -> Deque[int]:
|
||||
return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
|
||||
|
||||
@property
|
||||
def eval_logits(self) -> Deque[List[float]]:
|
||||
return deque(
|
||||
self.scores[: self.n_tokens, :].tolist(),
|
||||
maxlen=self._n_ctx if self.params.logits_all else 1,
|
||||
)
|
||||
|
||||
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
||||
"""Tokenize a string.
|
||||
|
@ -397,10 +415,7 @@ class Llama:
|
|||
|
||||
def reset(self):
|
||||
"""Reset the model state."""
|
||||
self.eval_tokens.clear()
|
||||
self.eval_logits.clear()
|
||||
self._input_ids = np.array([], dtype=np.intc)
|
||||
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
|
||||
self.n_tokens = 0
|
||||
|
||||
def eval(self, tokens: Sequence[int]):
|
||||
"""Evaluate a list of tokens.
|
||||
|
@ -410,7 +425,6 @@ class Llama:
|
|||
"""
|
||||
assert self.ctx is not None
|
||||
n_ctx = self._n_ctx
|
||||
scores: List[npt.NDArray[np.single]] = []
|
||||
for i in range(0, len(tokens), self.n_batch):
|
||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||
n_past = min(n_ctx - len(batch), len(self._input_ids))
|
||||
|
@ -425,19 +439,16 @@ class Llama:
|
|||
if return_code != 0:
|
||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||
# Save tokens
|
||||
self.eval_tokens.extend(batch)
|
||||
self._input_ids: npt.NDArray[np.intc] = np.concatenate(
|
||||
(self._input_ids, np.array(batch, dtype=np.intc)), axis=0
|
||||
)
|
||||
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
|
||||
# Save logits
|
||||
rows = n_tokens if self.params.logits_all else 1
|
||||
n_vocab = self._n_vocab
|
||||
cols = n_vocab
|
||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
|
||||
self.eval_logits.extend(logits)
|
||||
scores.append(np.array(logits, dtype=np.single))
|
||||
self._scores = np.concatenate(scores)
|
||||
self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
|
||||
# Update n_tokens
|
||||
self.n_tokens += n_tokens
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
|
@ -457,8 +468,7 @@ class Llama:
|
|||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
):
|
||||
assert self.ctx is not None
|
||||
assert len(self.eval_logits) > 0
|
||||
assert self._scores.shape[0] > 0
|
||||
assert self.n_tokens > 0
|
||||
n_vocab = self._n_vocab
|
||||
n_ctx = self._n_ctx
|
||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||
|
@ -475,7 +485,6 @@ class Llama:
|
|||
dtype=np.single,
|
||||
)
|
||||
self._scores[-1, :] = logits
|
||||
self.eval_logits[-1] = logits.tolist()
|
||||
|
||||
nl_logit = logits[self._token_nl]
|
||||
candidates = self._candidates
|
||||
|
@ -672,14 +681,7 @@ class Llama:
|
|||
print("Llama.generate: prefix-match hit", file=sys.stderr)
|
||||
reset = False
|
||||
tokens = tokens[longest_prefix:]
|
||||
self._input_ids = self._input_ids[:longest_prefix]
|
||||
self._scores = self._scores[:longest_prefix, :]
|
||||
for _ in range(len(self.eval_tokens) - longest_prefix):
|
||||
self.eval_tokens.pop()
|
||||
try:
|
||||
self.eval_logits.pop()
|
||||
except IndexError:
|
||||
pass
|
||||
self.n_tokens = longest_prefix
|
||||
|
||||
if reset:
|
||||
self.reset()
|
||||
|
@ -819,7 +821,9 @@ class Llama:
|
|||
llama_cpp.llama_reset_timings(self.ctx)
|
||||
|
||||
if len(prompt_tokens) > self._n_ctx:
|
||||
raise ValueError(f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}")
|
||||
raise ValueError(
|
||||
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
|
||||
)
|
||||
|
||||
# Truncate max_tokens if requested tokens would exceed the context window
|
||||
max_tokens = (
|
||||
|
@ -1513,22 +1517,20 @@ class Llama:
|
|||
file=sys.stderr,
|
||||
)
|
||||
return LlamaState(
|
||||
eval_tokens=self.eval_tokens.copy(),
|
||||
eval_logits=self.eval_logits.copy(),
|
||||
scores=self._scores.copy(),
|
||||
input_ids=self._input_ids.copy(),
|
||||
scores=self.scores.copy(),
|
||||
input_ids=self.input_ids.copy(),
|
||||
n_tokens=self.n_tokens,
|
||||
llama_state=bytes(llama_state_compact),
|
||||
llama_state_size=n_bytes,
|
||||
)
|
||||
|
||||
def load_state(self, state: LlamaState) -> None:
|
||||
assert self.ctx is not None
|
||||
self.eval_tokens = state.eval_tokens.copy()
|
||||
self.eval_logits = state.eval_logits.copy()
|
||||
self._scores = state.scores.copy()
|
||||
self._input_ids = state.input_ids.copy()
|
||||
self.scores = state.scores.copy()
|
||||
self.input_ids = state.input_ids.copy()
|
||||
self.n_tokens = state.n_tokens
|
||||
state_size = state.llama_state_size
|
||||
LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
|
||||
LLamaStateArrayType = llama_cpp.c_uint8 * state_size
|
||||
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
|
||||
|
||||
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
|
||||
|
|
Loading…
Reference in a new issue