diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 35823cf..0895182 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -324,6 +324,8 @@ class Llama: self._candidates = candidates self._token_nl = Llama.token_nl() self._token_eos = Llama.token_eos() + self._candidates_data_id = np.arange(self._n_vocab, dtype=np.intc) # type: ignore + self._candidates_data_p = np.zeros(self._n_vocab, dtype=np.single) self.n_tokens = 0 self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) @@ -487,9 +489,9 @@ class Llama: nl_logit = logits[self._token_nl] candidates = self._candidates candidates_data = self._candidates_data - candidates_data["id"][:] = np.arange(n_vocab, dtype=np.intc) # type: ignore + candidates_data["id"][:] = self._candidates_data_id # type: ignore candidates_data["logit"][:] = logits - candidates_data["p"][:] = np.zeros(n_vocab, dtype=np.single) + candidates_data["p"][:] = self._candidates_data_p # type: ignore candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p) candidates.sorted = llama_cpp.c_bool(False) candidates.size = llama_cpp.c_size_t(n_vocab)