diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 816cf11..29f468e 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -407,6 +407,7 @@ class Llama: """ assert self.ctx is not None n_ctx = self._n_ctx + scores = [] for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] n_past = min(n_ctx - len(batch), len(self._input_ids)) @@ -432,9 +433,8 @@ class Llama: logits_view = llama_cpp.llama_get_logits(self.ctx) logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] self.eval_logits.extend(logits) - self._scores: npt.NDArray[np.single] = np.concatenate( - (self._scores, np.array(logits, dtype=np.single)), axis=0 - ) + scores.append(np.array(logits, dtype=np.single)) + self._scores = np.concatenate(scores) def _sample( self,