misc: Improve llava error messages

This commit is contained in:
Andrei Betlen 2024-06-03 11:19:10 -04:00
parent a6457ba74b
commit 6b018e00b1

View file

@@ -2642,13 +2642,13 @@ class Llava15ChatHandler:
if type_ == "text": if type_ == "text":
tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True) tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
if llama.n_tokens + len(tokens) > llama.n_ctx(): if llama.n_tokens + len(tokens) > llama.n_ctx():
raise ValueError("Prompt exceeds n_ctx") # TODO: Fix raise ValueError(f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}")
llama.eval(tokens) llama.eval(tokens)
else: else:
image_bytes = self.load_image(value) image_bytes = self.load_image(value)
embed = embed_image_bytes(image_bytes) embed = embed_image_bytes(image_bytes)
if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
raise ValueError("Prompt exceeds n_ctx") # TODO: Fix raise ValueError(f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}")
n_past = ctypes.c_int(llama.n_tokens) n_past = ctypes.c_int(llama.n_tokens)
n_past_p = ctypes.pointer(n_past) n_past_p = ctypes.pointer(n_past)
with suppress_stdout_stderr(disable=self.verbose): with suppress_stdout_stderr(disable=self.verbose):