Fix sampling bug when logits_all=False

This commit is contained in:
Andrei Betlen 2023-11-10 05:15:41 -05:00
parent d9b38e3e3a
commit 6f0b0b1b84

View file

@ -1029,16 +1029,16 @@ class Llama:
)
self._ctx.decode(self._batch)
# Save tokens
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
self.input_ids[n_past : n_past + n_tokens] = batch
# Save logits
rows = n_tokens if self.context_params.logits_all else 1
rows = n_tokens
cols = self._n_vocab
offset = (
0 if self.context_params.logits_all else n_tokens - 1
) # NOTE: Only save the last token logits if logits_all is False
self.scores[self.n_tokens + offset : self.n_tokens + n_tokens, :].reshape(
self.scores[n_past + offset : n_past + n_tokens, :].reshape(
-1
)[:] = self._ctx.get_logits()[: rows * cols]
)[:] = self._ctx.get_logits()[offset * cols: rows * cols]
# Update n_tokens
self.n_tokens += n_tokens