feat: Improve Llama.eval performance by avoiding list conversion (#1476)
Co-authored-by: Andrei <abetlen@gmail.com>
parent 087cc0b036
commit 5cae1040e3
1 changed file with 2 additions and 2 deletions
@@ -562,12 +562,12 @@ class Llama:
             if self.context_params.logits_all:
                 rows = n_tokens
                 cols = self._n_vocab
-                logits = self._ctx.get_logits()[: rows * cols]
+                logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, ))
                 self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
             else:
                 rows = 1
                 cols = self._n_vocab
-                logits = self._ctx.get_logits()[: rows * cols]
+                logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, ))
                 self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
             # Update n_tokens
             self.n_tokens += n_tokens
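
Why this helps: slicing a ctypes float pointer (what get_logits() returns) builds an intermediate Python list of boxed floats one element at a time, whereas np.ctypeslib.as_array wraps the same memory as a NumPy view with no copy, so the assignment into self.scores reads straight from the native buffer. Below is a minimal standalone sketch of the difference; the sizes and the ctypes buffer are illustrative stand-ins for n_tokens, n_vocab, and the llama.cpp logits pointer, not part of this commit.

import ctypes
import numpy as np

# Stand-ins for n_tokens, n_vocab, and the float* returned by get_logits().
rows, cols = 4, 8
buf = (ctypes.c_float * (rows * cols))(*range(rows * cols))
ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))

# Old approach: slicing a ctypes pointer materializes a Python list of floats.
logits_list = ptr[: rows * cols]

# New approach: wrap the same memory as an ndarray view, no intermediate list.
logits_view = np.ctypeslib.as_array(ptr, shape=(rows * cols,))

# Either object can be assigned into the scores matrix; the view skips the
# per-element boxing and the temporary list.
scores = np.zeros((rows, cols), dtype=np.single)
scores.reshape(-1)[:] = logits_view

Note that the ndarray view is only valid while the underlying buffer is, which is why the values are copied into self.scores in the same step rather than kept around.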