fix: Make leading bos_token optional for image chat formats, fix nanollava system message

Andrei Betlen 2024-05-08 13:12:31 -04:00
parent 2a39b99575
commit 77122638b4


@@ -2603,7 +2603,12 @@ class Llava15ChatHandler:
         image_urls = self.get_image_urls(messages)
         template = jinja2.Template(self.CHAT_FORMAT)
-        text = template.render(messages=messages, add_generation_prompt=True)
+        text = template.render(
+            messages=messages,
+            add_generation_prompt=True,
+            eos_token=llama.detokenize([llama.token_eos()]),
+            bos_token=llama.detokenize([llama.token_bos()]),
+        )
         split_text = self.split_text_on_image_urls(text, image_urls)
         def embed_image_bytes(image_bytes: bytes):
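
With eos_token and bos_token now passed into the render call, an image chat format can decide for itself whether to emit a leading BOS instead of having one forced on it at tokenization time. Below is a minimal sketch of the idea, using an illustrative template rather than the real CHAT_FORMAT and placeholder token strings in place of llama.detokenize(...):

import jinja2

# Placeholder strings standing in for llama.detokenize([llama.token_bos()]) / token_eos().
bos_token = "<s>"
eos_token = "</s>"

# Illustrative format: the leading BOS is written by the template itself, so a
# format that does not want one simply leaves {{ bos_token }} out.
chat_format = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{{ message.role }}: {{ message.content }}{{ eos_token }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant:{% endif %}"
)

text = jinja2.Template(chat_format).render(
    messages=[{"role": "user", "content": "Describe the image."}],
    add_generation_prompt=True,
    eos_token=eos_token,
    bos_token=bos_token,
)
print(text)  # <s>user: Describe the image.</s>assistant: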
@@ -2624,9 +2629,9 @@ class Llava15ChatHandler:
         # Evaluate prompt
         llama.reset()
-        for i, (type_, value) in enumerate(split_text):
+        for type_, value in split_text:
             if type_ == "text":
-                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
+                tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
                 if llama.n_tokens + len(tokens) > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx") # TODO: Fix
                 llama.eval(tokens)
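
Because any BOS now arrives through the rendered text, tokenization must not inject a second one, and literal special tokens written by the template have to be parsed as such; that is what add_bos=False, special=True does. A rough sketch of the difference, assuming any local GGUF model at a placeholder path:

from llama_cpp import Llama

llama = Llama(model_path="./model.gguf", n_ctx=512, verbose=False)  # placeholder path

chunk = b"<s>USER: What is in the picture?"

# Old behaviour: a BOS id was prepended automatically (for the first chunk) and
# "<s>" in the text was split into ordinary byte-pair pieces.
old_tokens = llama.tokenize(chunk, add_bos=True, special=False)

# New behaviour: nothing is prepended, and "<s>" (or "<|im_start|>", etc.) is
# mapped to its special-token id only because the template actually wrote it.
new_tokens = llama.tokenize(chunk, add_bos=False, special=True)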
@@ -2644,6 +2649,8 @@ class Llava15ChatHandler:
                         llama.n_batch,
                         n_past_p,
                     )
+                # Required to avoid issues with hf tokenizer
+                llama.input_ids[llama.n_tokens : n_past.value] = -1
                 llama.n_tokens = n_past.value
         # Get prompt tokens to avoid a cache miss
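
The two added lines stamp the positions just consumed by the image embedding with -1. The sketch below is only an interpretation of why that matters for the cache check referenced in the next comment: those positions never held real token ids, and (per the "hf tokenizer" comment) stale ids left there could make a later prompt look like a cached prefix. All values are placeholders.

import numpy as np

# input_ids mirrors what has been evaluated into the context window.
input_ids = np.zeros(16, dtype=np.int64)
input_ids[:4] = [1, 100, 101, 102]   # text tokens already evaluated
n_tokens, n_past = 4, 10             # positions 4..9 were consumed by image embeddings

# Without the fix, positions 4..9 keep whatever ids were there before, so a
# prefix comparison against a new prompt could falsely match them.
input_ids[n_tokens:n_past] = -1      # -1 can never equal a real token id
n_tokens = n_past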
@@ -3033,6 +3040,7 @@ class NanoLlavaChatHandler(Llava15ChatHandler):
     # Answer the question<|im_end|><|im_start|>user
     # <image>
     # What is the picture about?<|im_end|><|im_start|>assistant
+    DEFAULT_SYSTEM_MESSAGE = "Answer the question"
     CHAT_FORMAT = (
         "{% for message in messages %}"