Fix top_k value. Closes #220

Andrei Betlen 2023-05-17 01:41:42 -04:00
parent e37a808bc0
commit 7e55244540


@@ -295,6 +295,7 @@ class Llama:
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
         n_vocab = int(llama_cpp.llama_n_vocab(self.ctx))
+        top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
         logits = self.eval_logits[-1]
         data = (llama_cpp.llama_token_data * n_vocab)(
             *[
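
For context, the added line expands a non-positive top_k to the full vocabulary size before sampling, so that top_k=0 behaves as "no top-k cutoff" rather than being passed through unchanged. A minimal sketch of the same normalization, using plain ctypes and a hypothetical normalize_top_k helper (the real code performs this inline inside Llama's sampling path):

```python
import ctypes

def normalize_top_k(top_k: ctypes.c_int, n_vocab: int) -> ctypes.c_int:
    """Mirror the fix: a non-positive top_k is replaced with the
    vocabulary size, i.e. 'consider every token' instead of an
    effectively empty candidate set."""
    return ctypes.c_int(n_vocab) if top_k.value <= 0 else top_k

# Example usage (32000 is just an illustrative vocab size):
print(normalize_top_k(ctypes.c_int(0), 32000).value)   # 32000
print(normalize_top_k(ctypes.c_int(40), 32000).value)  # 40
```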