diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 9da6b98..beff6ca 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -2603,7 +2603,12 @@ class Llava15ChatHandler:
 
         image_urls = self.get_image_urls(messages)
         template = jinja2.Template(self.CHAT_FORMAT)
-        text = template.render(messages=messages, add_generation_prompt=True)
+        text = template.render(
+            messages=messages,
+            add_generation_prompt=True,
+            eos_token=llama.detokenize([llama.token_eos()]),
+            bos_token=llama.detokenize([llama.token_bos()]),
+        )
         split_text = self.split_text_on_image_urls(text, image_urls)
 
         def embed_image_bytes(image_bytes: bytes):
@@ -2624,9 +2629,9 @@ class Llava15ChatHandler:
 
         # Evaluate prompt
         llama.reset()
-        for i, (type_, value) in enumerate(split_text):
+        for type_, value in split_text:
             if type_ == "text":
-                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
+                tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
                 if llama.n_tokens + len(tokens) > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                 llama.eval(tokens)
@@ -2644,6 +2649,8 @@ class Llava15ChatHandler:
                         llama.n_batch,
                         n_past_p,
                     )
+                # Required to avoid issues with hf tokenizer
+                llama.input_ids[llama.n_tokens : n_past.value] = -1
                 llama.n_tokens = n_past.value
 
         # Get prompt tokens to avoid a cache miss
@@ -3033,6 +3040,7 @@ class NanoLlavaChatHandler(Llava15ChatHandler):
     # Answer the question<|im_end|><|im_start|>user
     # <image>
     # What is the picture about?<|im_end|><|im_start|>assistant
+    DEFAULT_SYSTEM_MESSAGE = "Answer the question"
 
     CHAT_FORMAT = (
         "{% for message in messages %}"
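
Hypothetical usage sketch (not part of the diff above): a minimal way to exercise the prompt path this change touches, assuming the standard llama-cpp-python multimodal API. The model paths, image URL, and n_ctx value are placeholders.

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# The chat handler loads the CLIP projector; Llama loads the language model weights.
chat_handler = Llava15ChatHandler(clip_model_path="path/to/mmproj.gguf")
llm = Llama(
    model_path="path/to/llava-model.gguf",
    chat_handler=chat_handler,
    n_ctx=4096,  # the rendered prompt (text tokens + image embeddings) must fit in n_ctx
)

# With the change above, BOS/EOS come from the template render and the rendered
# text is tokenized with special=True instead of prepending BOS per chunk.
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are an assistant who describes images."},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
                {"type": "text", "text": "What is the picture about?"},
            ],
        },
    ]
)
print(response["choices"][0]["message"]["content"])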