Hotfix: logits_all bug

This commit is contained in:
Andrei Betlen 2023-06-29 00:57:27 -04:00
parent 4d1eb88b13
commit e34f4414cf

View file

@ -443,7 +443,8 @@ class Llama:
# Save logits
rows = n_tokens if self.params.logits_all else 1
cols = self._n_vocab
self.scores[self.n_tokens : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
offset = 0 if self.params.logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
self.scores[self.n_tokens + offset: self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
# Update n_tokens
self.n_tokens += n_tokens