diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 764c91e..8b7dee0 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -442,11 +442,8 @@ class Llama: self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch # Save logits rows = n_tokens if self.params.logits_all else 1 - n_vocab = self._n_vocab - cols = n_vocab - logits_view = llama_cpp.llama_get_logits(self.ctx) - logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] - self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits + cols = self._n_vocab + self.scores[self.n_tokens : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols] # Update n_tokens self.n_tokens += n_tokens