diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 2865d27..764c91e 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -141,7 +141,7 @@ class LlamaDiskCache(BaseLlamaCache): if _key is None: raise KeyError("Key not found") value: "LlamaState" = self.cache.pop(_key) # type: ignore - # NOTE: This puts an integer as key in cache, which breaks, + # NOTE: This puts an integer as key in cache, which breaks, # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens # self.cache.push(_key, side="front") # type: ignore return value @@ -166,17 +166,15 @@ class LlamaDiskCache(BaseLlamaCache): class LlamaState: def __init__( self, - eval_tokens: Deque[int], - eval_logits: Deque[List[float]], input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single], + n_tokens: int, llama_state: bytes, llama_state_size: int, ): - self.eval_tokens = eval_tokens - self.eval_logits = eval_logits self.input_ids = input_ids self.scores = scores + self.n_tokens = n_tokens self.llama_state = llama_state self.llama_state_size = llama_state_size @@ -267,8 +265,6 @@ class Llama: self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) - self.eval_tokens: Deque[int] = deque(maxlen=n_ctx) - self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1) self.cache: Optional[BaseLlamaCache] = None @@ -329,8 +325,30 @@ class Llama: self._token_nl = Llama.token_nl() self._token_eos = Llama.token_eos() - self._input_ids = np.array([], dtype=np.intc) - self._scores: npt.NDArray[np.single] = np.ndarray((0, self._n_vocab), dtype=np.single) + self.n_tokens = 0 + self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) + self.scores: npt.NDArray[np.single] = np.ndarray( + (n_ctx, self._n_vocab), dtype=np.single + ) + + @property + def _input_ids(self) -> npt.NDArray[np.intc]: + return self.input_ids[: self.n_tokens] + + @property + def _scores(self) -> npt.NDArray[np.single]: + return self.scores[: self.n_tokens, :] + + @property + def eval_tokens(self) -> Deque[int]: + return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx) + + @property + def eval_logits(self) -> Deque[List[float]]: + return deque( + self.scores[: self.n_tokens, :].tolist(), + maxlen=self._n_ctx if self.params.logits_all else 1, + ) def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: """Tokenize a string. @@ -397,10 +415,7 @@ class Llama: def reset(self): """Reset the model state.""" - self.eval_tokens.clear() - self.eval_logits.clear() - self._input_ids = np.array([], dtype=np.intc) - self._scores = np.ndarray((0, self._n_vocab), dtype=np.single) + self.n_tokens = 0 def eval(self, tokens: Sequence[int]): """Evaluate a list of tokens. @@ -410,7 +425,6 @@ class Llama: """ assert self.ctx is not None n_ctx = self._n_ctx - scores: List[npt.NDArray[np.single]] = [] for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] n_past = min(n_ctx - len(batch), len(self._input_ids)) @@ -425,19 +439,16 @@ class Llama: if return_code != 0: raise RuntimeError(f"llama_eval returned {return_code}") # Save tokens - self.eval_tokens.extend(batch) - self._input_ids: npt.NDArray[np.intc] = np.concatenate( - (self._input_ids, np.array(batch, dtype=np.intc)), axis=0 - ) + self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch # Save logits rows = n_tokens if self.params.logits_all else 1 n_vocab = self._n_vocab cols = n_vocab logits_view = llama_cpp.llama_get_logits(self.ctx) logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] - self.eval_logits.extend(logits) - scores.append(np.array(logits, dtype=np.single)) - self._scores = np.concatenate(scores) + self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits + # Update n_tokens + self.n_tokens += n_tokens def _sample( self, @@ -457,8 +468,7 @@ class Llama: logits_processor: Optional[LogitsProcessorList] = None, ): assert self.ctx is not None - assert len(self.eval_logits) > 0 - assert self._scores.shape[0] > 0 + assert self.n_tokens > 0 n_vocab = self._n_vocab n_ctx = self._n_ctx top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k @@ -475,7 +485,6 @@ class Llama: dtype=np.single, ) self._scores[-1, :] = logits - self.eval_logits[-1] = logits.tolist() nl_logit = logits[self._token_nl] candidates = self._candidates @@ -672,14 +681,7 @@ class Llama: print("Llama.generate: prefix-match hit", file=sys.stderr) reset = False tokens = tokens[longest_prefix:] - self._input_ids = self._input_ids[:longest_prefix] - self._scores = self._scores[:longest_prefix, :] - for _ in range(len(self.eval_tokens) - longest_prefix): - self.eval_tokens.pop() - try: - self.eval_logits.pop() - except IndexError: - pass + self.n_tokens = longest_prefix if reset: self.reset() @@ -819,7 +821,9 @@ class Llama: llama_cpp.llama_reset_timings(self.ctx) if len(prompt_tokens) > self._n_ctx: - raise ValueError(f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}") + raise ValueError( + f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}" + ) # Truncate max_tokens if requested tokens would exceed the context window max_tokens = ( @@ -1513,22 +1517,20 @@ class Llama: file=sys.stderr, ) return LlamaState( - eval_tokens=self.eval_tokens.copy(), - eval_logits=self.eval_logits.copy(), - scores=self._scores.copy(), - input_ids=self._input_ids.copy(), + scores=self.scores.copy(), + input_ids=self.input_ids.copy(), + n_tokens=self.n_tokens, llama_state=bytes(llama_state_compact), llama_state_size=n_bytes, ) def load_state(self, state: LlamaState) -> None: assert self.ctx is not None - self.eval_tokens = state.eval_tokens.copy() - self.eval_logits = state.eval_logits.copy() - self._scores = state.scores.copy() - self._input_ids = state.input_ids.copy() + self.scores = state.scores.copy() + self.input_ids = state.input_ids.copy() + self.n_tokens = state.n_tokens state_size = state.llama_state_size - LLamaStateArrayType = (llama_cpp.c_uint8 * state_size) + LLamaStateArrayType = llama_cpp.c_uint8 * state_size llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state) if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size: