From 6b018e00b15c18f0c72d85bdc64ee92f69b036ed Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 3 Jun 2024 11:19:10 -0400 Subject: [PATCH] misc: Improve llava error messages --- llama_cpp/llama_chat_format.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 8f3b1de..b384749 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2642,13 +2642,13 @@ class Llava15ChatHandler: if type_ == "text": tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True) if llama.n_tokens + len(tokens) > llama.n_ctx(): - raise ValueError("Prompt exceeds n_ctx") # TODO: Fix + raise ValueError(f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}") llama.eval(tokens) else: image_bytes = self.load_image(value) embed = embed_image_bytes(image_bytes) if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): - raise ValueError("Prompt exceeds n_ctx") # TODO: Fix + raise ValueError(f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}") n_past = ctypes.c_int(llama.n_tokens) n_past_p = ctypes.pointer(n_past) with suppress_stdout_stderr(disable=self.verbose):