fix: Make leading bos_token optional for image chat formats, fix nanollava system message
commit 77122638b4
parent 2a39b99575

1 changed file with 11 additions and 3 deletions
@@ -2603,7 +2603,12 @@ class Llava15ChatHandler:
         image_urls = self.get_image_urls(messages)
         template = jinja2.Template(self.CHAT_FORMAT)
-        text = template.render(messages=messages, add_generation_prompt=True)
+        text = template.render(
+            messages=messages,
+            add_generation_prompt=True,
+            eos_token=llama.detokenize([llama.token_eos()]),
+            bos_token=llama.detokenize([llama.token_bos()]),
+        )
         split_text = self.split_text_on_image_urls(text, image_urls)
 
         def embed_image_bytes(image_bytes: bytes):
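The render call now passes the model's actual BOS and EOS strings into the Jinja2 template, so each chat format can decide for itself whether to emit a leading {{ bos_token }} instead of having one forced on it by the tokenizer. A minimal sketch of the idea, using a made-up template (the real CHAT_FORMAT strings live on the handler classes):

# Sketch only; this template and its messages are illustrative.
import jinja2

CHAT_FORMAT = (
    "{{ bos_token }}"  # the template decides whether a leading BOS appears
    "{% for message in messages %}"
    "{{ message.role }}: {{ message.content }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}assistant: {% endif %}"
)

template = jinja2.Template(CHAT_FORMAT)
text = template.render(
    messages=[{"role": "user", "content": "What is in this image?"}],
    add_generation_prompt=True,
    bos_token="<s>",  # in the handler these come from llama.detokenize(...)
    eos_token="</s>",
)
print(text)  # "<s>user: What is in this image?\nassistant: "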
@@ -2624,9 +2629,9 @@ class Llava15ChatHandler:
 
         # Evaluate prompt
         llama.reset()
-        for i, (type_, value) in enumerate(split_text):
+        for type_, value in split_text:
             if type_ == "text":
-                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
+                tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
                 if llama.n_tokens + len(tokens) > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx") # TODO: Fix
                 llama.eval(tokens)
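With the template in charge of BOS, the tokenizer no longer needs the `add_bos=i == 0` special-casing of the first text chunk: `add_bos=False` skips the implicit BOS entirely, and `special=True` lets special tokens that the template rendered as plain text (such as <|im_start|>) map back to their single special token ids instead of being split into ordinary text pieces. A sketch of the difference, assuming `llama` is an already-loaded llama_cpp.Llama with a ChatML-style vocabulary:

# Illustrative comparison; token ids depend on the loaded model.
chunk = "<|im_start|>user\nWhat is the picture about?<|im_end|>"

# Old behavior: BOS forced onto the first chunk, special tokens not parsed.
tokens_old = llama.tokenize(chunk.encode("utf8"), add_bos=True)

# New behavior: no implicit BOS (the template emits {{ bos_token }} itself
# if the format wants one), and "<|im_start|>" becomes one special id.
tokens_new = llama.tokenize(chunk.encode("utf8"), add_bos=False, special=True)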
@@ -2644,6 +2649,8 @@ class Llava15ChatHandler:
                     llama.n_batch,
                     n_past_p,
                 )
+                # Required to avoid issues with hf tokenizer
+                llama.input_ids[llama.n_tokens : n_past.value] = -1
                 llama.n_tokens = n_past.value
 
         # Get prompt tokens to avoid a cache miss
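llava_eval_image_embed advances n_past without writing any token ids, so the positions it consumed would otherwise hold stale values in llama.input_ids. One plausible reading of the fix is that filling them with -1, which is never a valid token id, keeps later consumers of input_ids, such as the longest-prefix comparison used for prompt caching, from misinterpreting those positions as matching text. A simplified sketch of that prefix check (illustrative, not the library's actual implementation):

# Assumes numpy-backed input_ids, as llama-cpp-python uses internally.
import numpy as np

def longest_common_prefix(cached: np.ndarray, prompt: np.ndarray) -> int:
    n = min(len(cached), len(prompt))
    mismatch = np.nonzero(cached[:n] != prompt[:n])[0]
    return int(mismatch[0]) if len(mismatch) else n

# Positions covered by an image embedding are filled with -1, so they can
# never equal a real prompt token and the prefix match stops before them.
cached = np.array([1, 15, -1, -1, 42])
prompt = np.array([1, 15, 99, 100, 42])
assert longest_common_prefix(cached, prompt) == 2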
@@ -3033,6 +3040,7 @@ class NanoLlavaChatHandler(Llava15ChatHandler):
     # Answer the question<|im_end|><|im_start|>user
     # <image>
     # What is the picture about?<|im_end|><|im_start|>assistant
+    DEFAULT_SYSTEM_MESSAGE = "Answer the question"
 
     CHAT_FORMAT = (
         "{% for message in messages %}"
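NanoLlava expects the fixed system prompt shown in the comment above, so the handler now declares it as DEFAULT_SYSTEM_MESSAGE, to be injected when the caller supplies no system message (presumably by the shared Llava15ChatHandler logic). A hypothetical helper showing the intended behavior; this function is illustrative, not the handler's actual code:

def with_default_system_message(messages, default="Answer the question"):
    # Prepend the handler's default only if no system message is present.
    if not any(m["role"] == "system" for m in messages):
        return [{"role": "system", "content": default}] + messages
    return messages

messages = with_default_system_message(
    [{"role": "user", "content": "What is the picture about?"}]
)
# -> a system turn "Answer the question" followed by the user turn,
#    matching the ChatML-style prompt in the comment above.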