bugfix: truncate completion max_tokens to fit context length by default

Andrei Betlen 2023-07-09 18:13:29 -04:00
parent 6f70cc4b7d
commit a86bfdf0a5


@@ -824,18 +824,14 @@ class Llama:
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
 
-        if max_tokens <= 0:
-            # Unlimited, depending on n_ctx.
-            if len(prompt_tokens) >= int(llama_cpp.llama_n_ctx(self.ctx)):
-                raise ValueError(
-                    f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
-                )
-            else:
-                max_tokens = int(llama_cpp.llama_n_ctx(self.ctx)) - len(prompt_tokens)
-        elif len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
-            raise ValueError(
-                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
-            )
+        if len(prompt_tokens) >= llama_cpp.llama_n_ctx(self.ctx):
+            raise ValueError(
+                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
+            )
+
+        if max_tokens <= 0:
+            # Unlimited, depending on n_ctx.
+            max_tokens = llama_cpp.llama_n_ctx(self.ctx) - len(prompt_tokens)
+
         # Truncate max_tokens if requested tokens would exceed the context window
         max_tokens = (
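
For reference, a minimal standalone sketch of the resulting behavior. The hunk above is cut off before the clamping expression, so the helper below (its name resolve_max_tokens, its arguments, and the final min-based clamp) is an assumption based on the commit title and the visible comment, not the library's actual code.

# Sketch only: resolve_max_tokens and its parameters are hypothetical stand-ins
# for values the real method reads via llama_cpp.llama_n_ctx(self.ctx) and the
# tokenized prompt; the final clamp is inferred from the commit title.
def resolve_max_tokens(max_tokens: int, prompt_len: int, n_ctx: int) -> int:
    # The prompt alone must still fit inside the context window.
    if prompt_len >= n_ctx:
        raise ValueError(f"Requested tokens exceed context window of {n_ctx}")
    # max_tokens <= 0 now means "use whatever context remains" rather than erroring.
    if max_tokens <= 0:
        return n_ctx - prompt_len
    # Otherwise truncate the request so prompt + completion still fit in n_ctx.
    return min(max_tokens, n_ctx - prompt_len)

assert resolve_max_tokens(-1, prompt_len=100, n_ctx=512) == 412      # default: fill remaining context
assert resolve_max_tokens(10_000, prompt_len=100, n_ctx=512) == 412  # oversized request is truncated
assert resolve_max_tokens(16, prompt_len=100, n_ctx=512) == 16       # small request passes through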