diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 688b2a7..35823cf 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -487,9 +487,9 @@ class Llama: nl_logit = logits[self._token_nl] candidates = self._candidates candidates_data = self._candidates_data - candidates_data["id"] = np.arange(n_vocab, dtype=np.intc) # type: ignore - candidates_data["logit"] = logits - candidates_data["p"] = np.zeros(n_vocab, dtype=np.single) + candidates_data["id"][:] = np.arange(n_vocab, dtype=np.intc) # type: ignore + candidates_data["logit"][:] = logits + candidates_data["p"][:] = np.zeros(n_vocab, dtype=np.single) candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p) candidates.sorted = llama_cpp.c_bool(False) candidates.size = llama_cpp.c_size_t(n_vocab)