fix: segfault when logits_all=False. Closes #1319

This commit is contained in:
Andrei Betlen 2024-04-03 15:30:31 -04:00
parent f96de6d920
commit 8649d7671b

View file

@ -535,14 +535,16 @@ class Llama:
# Save tokens # Save tokens
self.input_ids[n_past : n_past + n_tokens] = batch self.input_ids[n_past : n_past + n_tokens] = batch
# Save logits # Save logits
rows = n_tokens if self.context_params.logits_all:
cols = self._n_vocab rows = n_tokens
offset = ( cols = self._n_vocab
0 if self.context_params.logits_all else n_tokens - 1 logits = self._ctx.get_logits()[: rows * cols]
) # NOTE: Only save the last token logits if logits_all is False self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[ else:
: rows = 1
] = self._ctx.get_logits()[offset * cols : rows * cols] cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
# Update n_tokens # Update n_tokens
self.n_tokens += n_tokens self.n_tokens += n_tokens