fix: Make leading bos_token optional for image chat formats, fix nanollava system message
parent 2a39b99575
commit 77122638b4
1 changed file with 11 additions and 3 deletions
@@ -2603,7 +2603,12 @@ class Llava15ChatHandler:
         image_urls = self.get_image_urls(messages)
         template = jinja2.Template(self.CHAT_FORMAT)
-        text = template.render(messages=messages, add_generation_prompt=True)
+        text = template.render(
+            messages=messages,
+            add_generation_prompt=True,
+            eos_token=llama.detokenize([llama.token_eos()]),
+            bos_token=llama.detokenize([llama.token_bos()]),
+        )
         split_text = self.split_text_on_image_urls(text, image_urls)

         def embed_image_bytes(image_bytes: bytes):
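The practical effect of this hunk: BOS/EOS are no longer hard-wired into tokenization; the template string itself decides whether (and where) they appear. A minimal sketch of the idea, using an invented template and stand-in token strings rather than the library's actual CHAT_FORMAT:

# Minimal sketch -- an invented template, not the real Llava15
# CHAT_FORMAT, with "<s>"/"</s>" standing in for the byte strings
# llama.detokenize() would return.
import jinja2

CHAT_FORMAT = (
    "{{ bos_token }}"  # the template opts in to a leading BOS here
    "{% for message in messages %}"
    "{{ message.role }}: {{ message.content }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant:{% endif %}"
)

text = jinja2.Template(CHAT_FORMAT).render(
    messages=[{"role": "user", "content": "What is in the image?"}],
    add_generation_prompt=True,
    bos_token="<s>",
    eos_token="</s>",
)
print(text)
# <s>user: What is in the image?
# assistant:

A template that simply omits "{{ bos_token }}" produces no leading BOS at all, which is what makes it optional.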
@@ -2624,9 +2629,9 @@
         # Evaluate prompt
         llama.reset()
-        for i, (type_, value) in enumerate(split_text):
+        for type_, value in split_text:
             if type_ == "text":
-                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
+                tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
                 if llama.n_tokens + len(tokens) > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                 llama.eval(tokens)
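The switch to add_bos=False, special=True is the other half of the same fix: the BOS string now lives in the rendered text, so the tokenizer must parse it as a special token instead of mechanically prepending one to the first chunk. A toy tokenizer (not llama-cpp-python's) to show the two paths agree:

# Toy tokenizer, only to illustrate the control flow; real subword
# tokenization is irrelevant here, so characters stand in for token ids.
BOS_ID = 1

def toy_tokenize(text: str, add_bos: bool, special: bool) -> list[int]:
    ids = [BOS_ID] if add_bos else []
    if special and text.startswith("<s>"):
        ids.append(BOS_ID)          # BOS parsed out of the text itself
        text = text[len("<s>"):]
    ids.extend(ord(c) for c in text)
    return ids

# Old behavior: plain text chunks, BOS prepended to chunk 0 only.
old = [toy_tokenize(c, add_bos=(i == 0), special=False)
       for i, c in enumerate(["user: hi", " assistant:"])]
# New behavior: the template rendered bos_token into the text.
new = [toy_tokenize(c, add_bos=False, special=True)
       for c in ["<s>user: hi", " assistant:"]]
assert old == new  # same ids; the template now controls the BOS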
@@ -2644,6 +2649,8 @@
                     llama.n_batch,
                     n_past_p,
                 )
+                # Required to avoid issues with hf tokenizer
+                llama.input_ids[llama.n_tokens : n_past.value] = -1
                 llama.n_tokens = n_past.value

         # Get prompt tokens to avoid a cache miss
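The new input_ids line deserves a note. llava_eval_image_embed advances the KV cache past positions that have no token ids, and the likely rationale for filling those slots with -1 is that prompt caching compares input_ids against the next prompt's ids: -1 is a value no tokenizer (including an HF one) can ever produce, so stale ids can't leak into that comparison. A rough numpy sketch of the failure this would prevent, with invented token ids:

# Rough numpy sketch; llama.input_ids is the real attribute, but the
# ids and buffer size here are invented for illustration.
import numpy as np

input_ids = np.zeros(16, dtype=np.int64)   # fresh buffer, all zeros
input_ids[:3] = [1, 15, 27]                # three text tokens evaluated
n_past = 3 + 5                             # image embed advanced cache by 5

# Without the fix, slots 3..8 keep their stale zeros. If the next prompt
# happens to contain token id 0 there, the prefix match overruns the image:
next_ids = np.array([1, 15, 27, 0, 0, 0, 0, 0])
print((input_ids[:8] == next_ids).sum())   # 8 -- bogus full-prefix match

# The fix marks the embedding slots with -1, which no real id can equal:
input_ids[3:n_past] = -1
print((input_ids[:8] == next_ids).sum())   # 3 -- match stops at the text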
@@ -3033,6 +3040,7 @@ class NanoLlavaChatHandler(Llava15ChatHandler):
     # Answer the question<|im_end|><|im_start|>user
     # <image>
     # What is the picture about?<|im_end|><|im_start|>assistant
+    DEFAULT_SYSTEM_MESSAGE = "Answer the question"

     CHAT_FORMAT = (
         "{% for message in messages %}"
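For NanoLlava the fix is the DEFAULT_SYSTEM_MESSAGE constant: the commented example above shows the intended rendering, where "Answer the question" fills the system slot when the caller sends none. A simplified sketch of that injection, using a cut-down ChatML-style template rather than NanoLlava's real CHAT_FORMAT:

# Simplified ChatML-style template; the real CHAT_FORMAT also handles
# image content parts, omitted here for brevity.
import jinja2

DEFAULT_SYSTEM_MESSAGE = "Answer the question"
template = jinja2.Template(
    "{% for m in messages %}"
    "<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>"
    "{% endfor %}"
    "<|im_start|>assistant\n"
)

messages = [{"role": "user", "content": "What is the picture about?"}]
if not any(m["role"] == "system" for m in messages):
    messages.insert(0, {"role": "system", "content": DEFAULT_SYSTEM_MESSAGE})

print(template.render(messages=messages))
# <|im_start|>system
# Answer the question<|im_end|><|im_start|>user
# What is the picture about?<|im_end|><|im_start|>assistant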