From e34f4414cf36f24dbe83f2e4baa610f207b5eb4d Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Thu, 29 Jun 2023 00:57:27 -0400
Subject: [PATCH] Hotfix: logits_all bug

---
 llama_cpp/llama.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 8b7dee0..688b2a7 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -443,7 +443,8 @@ class Llama:
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             cols = self._n_vocab
-            self.scores[self.n_tokens : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
+            offset = 0 if self.params.logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
+            self.scores[self.n_tokens + offset: self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
             # Update n_tokens
             self.n_tokens += n_tokens
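
For context (not part of the patch), here is a minimal NumPy sketch of why the offset is needed. The toy values for n_vocab, n_ctx, and the fake_logits list are assumptions standing in for llama_cpp.llama_get_logits(); the other names mirror the patched code.

    import numpy as np

    n_vocab = 4         # assumed toy vocabulary size
    n_ctx = 8           # assumed toy context length
    scores = np.zeros((n_ctx, n_vocab), dtype=np.single)

    n_tokens = 3        # tokens evaluated in this batch
    n_past = 0          # stand-in for self.n_tokens (tokens already scored)
    logits_all = False  # only the last token's logits are returned

    # The C side exposes rows * cols floats: one row per token when
    # logits_all is True, otherwise a single row for the last token.
    rows = n_tokens if logits_all else 1
    cols = n_vocab
    fake_logits = [float(i) for i in range(rows * cols)]  # stand-in for llama_get_logits()

    # Old code: the target slice spans all n_tokens rows, so with logits_all
    # False the cols values on the right cannot fill n_tokens * cols slots.
    # scores[n_past : n_past + n_tokens, :].reshape(-1)[:] = fake_logits  # ValueError

    # Patched code: offset narrows the target to the last row when logits_all is False.
    offset = 0 if logits_all else n_tokens - 1
    scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[:] = fake_logits
    print(scores[:n_tokens])  # only row n_tokens - 1 holds the logits

With logits_all enabled the offset is 0 and the assignment fills all n_tokens rows, so the patched line reduces to the old behaviour in that case.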