Merge branch 'main' into fix-state-pickle

This commit is contained in:
Andrei 2023-06-26 08:45:02 -04:00 committed by GitHub
commit 5eb4ebb041
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -407,6 +407,7 @@ class Llama:
""" """
assert self.ctx is not None assert self.ctx is not None
n_ctx = self._n_ctx n_ctx = self._n_ctx
scores = []
for i in range(0, len(tokens), self.n_batch): for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)] batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), len(self._input_ids)) n_past = min(n_ctx - len(batch), len(self._input_ids))
@ -432,9 +433,8 @@ class Llama:
logits_view = llama_cpp.llama_get_logits(self.ctx) logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
self.eval_logits.extend(logits) self.eval_logits.extend(logits)
self._scores: npt.NDArray[np.single] = np.concatenate( scores.append(np.array(logits, dtype=np.single))
(self._scores, np.array(logits, dtype=np.single)), axis=0 self._scores = np.concatenate(scores)
)
def _sample( def _sample(
self, self,