Use n_ctx provided from actual context not params

This commit is contained in:
Andrei Betlen 2023-03-24 14:58:10 -04:00
parent 2cc499512c
commit b9c53b88a1

View file

@ -60,12 +60,12 @@ class Llama:
stop = [s.encode("utf-8") for s in stop]
prompt_tokens = llama_cpp.llama_tokenize(
self.ctx, prompt.encode("utf-8"), self.tokens, self.params.n_ctx, True
self.ctx, prompt.encode("utf-8"), self.tokens, llama_cpp.llama_n_ctx(self.ctx), True
)
if prompt_tokens + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
if prompt_tokens + max_tokens > self.params.n_ctx:
raise ValueError(
f"Requested tokens exceed context window of {self.params.n_ctx}"
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
)
# Process prompt in chunks to avoid running out of memory