diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 51bb8b2..1d5a5f4 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -169,6 +169,11 @@ class Llama:
             The sampled token.
         """
         assert self.ctx is not None
+        # Temporary workaround for https://github.com/ggerganov/llama.cpp/issues/684
+        if temp == 0.0:
+            temp = 1.0
+            top_p = 0.0
+            top_k = 1
         return llama_cpp.llama_sample_top_p_top_k(
             ctx=self.ctx,
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
@@ -209,11 +214,6 @@ class Llama:
         Yields:
             The generated tokens.
         """
-        # Temporary workaround for https://github.com/ggerganov/llama.cpp/issues/684
-        if temp == 0.0:
-            temp = 1.0
-            top_p = 0.0
-            top_k = 1
         assert self.ctx is not None
         self.reset()
         while True: