bugfix: truncate completion max_tokens to fit context length by default

This commit is contained in:
Andrei Betlen 2023-07-09 18:13:29 -04:00
parent 6f70cc4b7d
commit a86bfdf0a5

View file

@ -824,18 +824,14 @@ class Llama:
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
if max_tokens <= 0:
# Unlimited, depending on n_ctx.
if len(prompt_tokens) >= int(llama_cpp.llama_n_ctx(self.ctx)):
if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
raise ValueError(
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
)
else:
max_tokens = int(llama_cpp.llama_n_ctx(self.ctx)) - len(prompt_tokens)
elif len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
raise ValueError(
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
)
if max_tokens <= 0:
# Unlimited, depending on n_ctx.
max_tokens = llama_cpp.llama_n_ctx(self.ctx) - len(prompt_tokens)
# Truncate max_tokens if requested tokens would exceed the context window
max_tokens = (