Load logits directly into scores buffer

This commit is contained in:
Andrei Betlen 2023-06-29 00:45:46 -04:00
parent b95b0ffbeb
commit a2ede37bd5

View file

@ -442,11 +442,8 @@ class Llama:
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
# Save logits # Save logits
rows = n_tokens if self.params.logits_all else 1 rows = n_tokens if self.params.logits_all else 1
n_vocab = self._n_vocab cols = self._n_vocab
cols = n_vocab self.scores[self.n_tokens : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
# Update n_tokens # Update n_tokens
self.n_tokens += n_tokens self.n_tokens += n_tokens