diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 6babebd..4f10227 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -299,6 +299,8 @@ class Llama:
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)

     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -310,7 +312,7 @@ class Llama:
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -356,6 +358,7 @@ class Llama:
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -368,7 +371,7 @@ class Llama:

         if logits_processor is not None:
             logits = np.array(
-                logits_processor(list(self.eval_tokens), logits.tolist()),
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
@@ -498,8 +501,8 @@ class Llama:
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
@@ -557,9 +560,9 @@ class Llama:
         """
         assert self.ctx is not None

-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -569,6 +572,8 @@ class Llama:
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                     try:
@@ -595,7 +600,7 @@ class Llama:
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
@@ -820,7 +825,7 @@ class Llama:
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -869,7 +874,7 @@ class Llama:
                 break

         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
         ):
             text = self.detokenize(completion_tokens)
             finish_reason = "stop"
@@ -899,7 +904,7 @@ class Llama:
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -1001,8 +1006,7 @@ class Llama:
             for token in all_tokens
         ]
         all_logprobs = [
-            Llama.logits_to_logprobs(list(map(float, row)))
-            for row in self.eval_logits
+            Llama.logits_to_logprobs(row.tolist()) for row in self._scores
         ][token_offset:]
         for token, token_str, logprobs_token in zip(
            all_tokens, all_token_strs, all_logprobs
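
Note: the hunks above only show where the new numpy-backed state is reset, read, and truncated; the part of eval() that appends freshly evaluated tokens and logits to _input_ids and _scores is outside this excerpt. The following standalone sketch illustrates, under that assumption, how such bookkeeping would mirror the existing eval_tokens/eval_logits deques; append_batch, n_vocab, and the sample values are illustrative names, not code from the commit.

import numpy as np

# Hypothetical sketch (not part of this diff): keep a token-id vector and a
# (n_tokens, n_vocab) float32 score matrix in sync after each evaluated batch.
n_vocab = 32000

# Mirrors Llama.reset(): empty token history, empty (0, n_vocab) score matrix.
input_ids = np.array([], dtype=np.intc)
scores = np.ndarray((0, n_vocab), dtype=np.single)

def append_batch(input_ids, scores, batch, batch_logits):
    """Append one evaluated batch, as eval() presumably does after llama_eval()."""
    input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)))
    scores = np.concatenate((scores, batch_logits.astype(np.single)), axis=0)
    return input_ids, scores

# Example: pretend two tokens were evaluated and produced two rows of logits.
batch = [15043, 2787]
batch_logits = np.zeros((len(batch), n_vocab), dtype=np.single)
input_ids, scores = append_batch(input_ids, scores, batch, batch_logits)

# The reads in the diff then become cheap array operations: the last row of
# logits for sampling, and a prefix truncation on a generate() cache hit.
last_logits = scores[-1, :]
longest_prefix = 1
input_ids, scores = input_ids[:longest_prefix], scores[:longest_prefix, :]

Keeping the scores as one float32 matrix is what lets the logprob paths above slice rows directly (self._scores[token_offset, :].tolist()) instead of converting deque entries with list(map(float, row)).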