diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 62e0dae..edb68c9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -824,18 +824,14 @@ class Llama: if self.verbose: llama_cpp.llama_reset_timings(self.ctx) + if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx): + raise ValueError( + f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" + ) + if max_tokens <= 0: # Unlimited, depending on n_ctx. - if len(prompt_tokens) >= int(llama_cpp.llama_n_ctx(self.ctx)): - raise ValueError( - f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" - ) - else: - max_tokens = int(llama_cpp.llama_n_ctx(self.ctx)) - len(prompt_tokens) - elif len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)): - raise ValueError( - f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}" - ) + max_tokens = llama_cpp.llama_n_ctx(self.ctx) - len(prompt_tokens) # Truncate max_tokens if requested tokens would exceed the context window max_tokens = (