From a86bfdf0a50f23a6aebb3f095ada0afcf8791d6e Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Sun, 9 Jul 2023 18:13:29 -0400
Subject: [PATCH] bugfix: truncate completion max_tokens to fit context length
 by default

---
 llama_cpp/llama.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 62e0dae..edb68c9 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -824,18 +824,14 @@ class Llama:
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
 
+        if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
+            raise ValueError(
+                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+            )
+
         if max_tokens <= 0:
             # Unlimited, depending on n_ctx.
-            if len(prompt_tokens) >= int(llama_cpp.llama_n_ctx(self.ctx)):
-                raise ValueError(
-                    f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
-                )
-            else:
-                max_tokens = int(llama_cpp.llama_n_ctx(self.ctx)) - len(prompt_tokens)
-        elif len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
-            raise ValueError(
-                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
-            )
+            max_tokens = llama_cpp.llama_n_ctx(self.ctx) - len(prompt_tokens)
 
         # Truncate max_tokens if requested tokens would exceed the context window
         max_tokens = (
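
For reference, a minimal standalone sketch of the clamping behaviour this patch introduces: reject prompts that already fill the context window, treat max_tokens <= 0 as "use the remaining context", and otherwise truncate max_tokens so prompt plus completion fit in n_ctx. The helper name clamp_max_tokens and the example numbers are illustrative only, not part of llama_cpp.

    # Illustrative sketch, not part of the patch.
    def clamp_max_tokens(max_tokens: int, n_prompt_tokens: int, n_ctx: int) -> int:
        # The prompt alone must fit inside the context window.
        if n_prompt_tokens >= n_ctx:
            raise ValueError(f"Requested tokens exceed context window of {n_ctx}")
        # max_tokens <= 0 means "unlimited": fill whatever context remains.
        if max_tokens <= 0:
            return n_ctx - n_prompt_tokens
        # Otherwise truncate max_tokens so prompt + completion fit in n_ctx.
        return min(max_tokens, n_ctx - n_prompt_tokens)

    # Example: a 512-token context with a 500-token prompt leaves room for at
    # most 12 completion tokens, even if the caller asked for 256.
    assert clamp_max_tokens(256, 500, 512) == 12
    assert clamp_max_tokens(0, 500, 512) == 12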