diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 05994b6..4b6ce8c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -811,9 +811,16 @@ class Llama: if self.verbose: llama_cpp.llama_reset_timings(self.ctx) - if len(prompt_tokens) + max_tokens > self._n_ctx: + if len(prompt_tokens) > self._n_ctx: raise ValueError(f"Requested tokens exceed context window of {self._n_ctx}") + # Truncate max_tokens if requested tokens would exceed the context window + max_tokens = ( + max_tokens + if max_tokens + len(prompt_tokens) < self._n_ctx + else (self._n_ctx - len(prompt_tokens)) + ) + if stop != []: stop_sequences = [s.encode("utf-8") for s in stop] else: