fix: segfault when logits_all=False. Closes #1319
This commit is contained in:
parent
f96de6d920
commit
8649d7671b
1 changed files with 10 additions and 8 deletions
|
@ -535,14 +535,16 @@ class Llama:
|
|||
# Save tokens
|
||||
self.input_ids[n_past : n_past + n_tokens] = batch
|
||||
# Save logits
|
||||
rows = n_tokens
|
||||
cols = self._n_vocab
|
||||
offset = (
|
||||
0 if self.context_params.logits_all else n_tokens - 1
|
||||
) # NOTE: Only save the last token logits if logits_all is False
|
||||
self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
|
||||
:
|
||||
] = self._ctx.get_logits()[offset * cols : rows * cols]
|
||||
if self.context_params.logits_all:
|
||||
rows = n_tokens
|
||||
cols = self._n_vocab
|
||||
logits = self._ctx.get_logits()[: rows * cols]
|
||||
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
|
||||
else:
|
||||
rows = 1
|
||||
cols = self._n_vocab
|
||||
logits = self._ctx.get_logits()[: rows * cols]
|
||||
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
|
||||
# Update n_tokens
|
||||
self.n_tokens += n_tokens
|
||||
|
||||
|
|
Loading…
Reference in a new issue