feat: Improve Llama.eval performance by avoiding list conversion (#1476)

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Linghan Zhong 2024-05-24 00:49:44 -05:00 committed by GitHub
parent 087cc0b036
commit 5cae1040e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -562,12 +562,12 @@ class Llama:
if self.context_params.logits_all:
rows = n_tokens
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, ))
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
else:
rows = 1
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, ))
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
# Update n_tokens
self.n_tokens += n_tokens