From 556c7edf47352036f7d876534a2b3ce4e1586a36 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Fri, 9 Jun 2023 10:57:36 -0400
Subject: [PATCH] Truncate max_tokens if it exceeds context length

---
 llama_cpp/llama.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 05994b6..4b6ce8c 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -811,9 +811,16 @@ class Llama:
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
 
-        if len(prompt_tokens) + max_tokens > self._n_ctx:
+        if len(prompt_tokens) > self._n_ctx:
             raise ValueError(f"Requested tokens exceed context window of {self._n_ctx}")
 
+        # Truncate max_tokens if requested tokens would exceed the context window
+        max_tokens = (
+            max_tokens
+            if max_tokens + len(prompt_tokens) < self._n_ctx
+            else (self._n_ctx - len(prompt_tokens))
+        )
+
         if stop != []:
            stop_sequences = [s.encode("utf-8") for s in stop]
         else:
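
For illustration only (not part of the patch), here is a minimal standalone sketch of the clamping behaviour this commit introduces. clamp_max_tokens is a hypothetical helper name chosen for the example; in the actual change the logic lives inline in llama_cpp/llama.py as shown in the diff above.

# Sketch only, not part of the patch: shows the effect of the new truncation.
def clamp_max_tokens(max_tokens: int, prompt_len: int, n_ctx: int) -> int:
    """Clamp max_tokens so prompt + completion fit inside the context window."""
    if prompt_len > n_ctx:
        # The prompt alone does not fit; same error the patched code raises.
        raise ValueError(f"Requested tokens exceed context window of {n_ctx}")
    # Equivalent in effect to the conditional expression added in the patch.
    return min(max_tokens, n_ctx - prompt_len)

if __name__ == "__main__":
    # Before the patch this request would raise; after it, max_tokens is
    # truncated to the 48 tokens of context that remain after the prompt.
    print(clamp_max_tokens(max_tokens=256, prompt_len=464, n_ctx=512))  # -> 48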