diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 054bdc2..4c8ba39 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -295,6 +295,7 @@ class Llama: assert self.ctx is not None assert len(self.eval_logits) > 0 n_vocab = int(llama_cpp.llama_n_vocab(self.ctx)) + top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k logits = self.eval_logits[-1] data = (llama_cpp.llama_token_data * n_vocab)( *[