Load logits directly into scores buffer
This commit is contained in:
parent
b95b0ffbeb
commit
a2ede37bd5
1 changed files with 2 additions and 5 deletions
|
@ -442,11 +442,8 @@ class Llama:
|
||||||
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
|
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
|
||||||
# Save logits
|
# Save logits
|
||||||
rows = n_tokens if self.params.logits_all else 1
|
rows = n_tokens if self.params.logits_all else 1
|
||||||
n_vocab = self._n_vocab
|
cols = self._n_vocab
|
||||||
cols = n_vocab
|
self.scores[self.n_tokens : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
|
||||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
|
||||||
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
|
|
||||||
self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
|
|
||||||
# Update n_tokens
|
# Update n_tokens
|
||||||
self.n_tokens += n_tokens
|
self.n_tokens += n_tokens
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue