From 31b1d95a6c19f5b615a3286069f181a415f872e8 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 2 May 2024 11:32:18 -0400 Subject: [PATCH] feat: Add llama-3-vision-alpha chat format --- llama_cpp/llama_chat_format.py | 64 ++++++++++++++++++++++++++++++++-- llama_cpp/server/model.py | 14 ++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 0af410a..a17b86b 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2165,7 +2165,7 @@ def functionary_v1_v2_chat_handler( class Llava15ChatHandler: - DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + DEFAULT_SYSTEM_MESSAGE: Optional[str] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." CHAT_FORMAT = ( "{% for message in messages %}" @@ -2288,7 +2288,7 @@ class Llava15ChatHandler: assert self.clip_ctx is not None system_prompt = _get_system_message(messages) - if system_prompt == "": + if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None: messages = [llama_types.ChatCompletionRequestSystemMessage(role="system", content=self.DEFAULT_SYSTEM_MESSAGE)] + messages image_urls = self.get_image_urls(messages) @@ -2771,6 +2771,66 @@ class NanoLlavaChatHandler(Llava15ChatHandler): "{% endif %}" ) +class Llama3VisionAlpha(Llava15ChatHandler): + # question = "" + q + + # prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + DEFAULT_SYSTEM_MESSAGE = None + + CHAT_FORMAT = ( + "{% for message in messages %}" + + "<|start_header_id|>" + + "{% if message.role == 'user' %}" + + "user<|end_header_id|>\n\n" + + "{% if message.content is iterable %}" + + # + "{% for content in message.content %}" + "{% if content.type == 'image_url' %}" + "{% if content.image_url is string %}" + "{{ content.image_url }}" + "{% endif %}" + "{% if content.image_url is mapping %}" + "{{ content.image_url.url }}" + "{% endif %}" + "{% endif %}" + "{% endfor %}" + + # Question: + "{% for content in message.content %}" + "{% if content.type == 'text' %}" + "{{ content.text }}" + "{% endif %}" + "{% endfor %}" + + "{% endif %}" + + # Question: + "{% if message.content is string %}" + "{{ message.content }}" + "{% endif %}" + + "{% endif %}" + + # Answer: + "{% if message.role == 'assistant' %}" + "assistant<|end_header_id|>\n\n" + "{{ message.content }}" + "{% endif %}" + + "<|eot_id|>" + + "{% endfor %}" + + # Generation prompt + "{% if add_generation_prompt %}" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + "{% endif %}" + ) @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index 1ad592d..e102fad 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -140,6 +140,20 @@ class LlamaProxy: chat_handler = llama_cpp.llama_chat_format.NanoLlavaChatHandler( clip_model_path=settings.clip_model_path, verbose=settings.verbose ) + elif settings.chat_format == "llama-3-vision-alpha": + assert settings.clip_model_path is not None, "clip model not found" + if settings.hf_model_repo_id is not None: + chat_handler = ( + llama_cpp.llama_chat_format.Llama3VisionAlpha.from_pretrained( + repo_id=settings.hf_model_repo_id, + filename=settings.clip_model_path, + verbose=settings.verbose, + ) + ) + else: + chat_handler = llama_cpp.llama_chat_format.Llama3VisionAlpha( + clip_model_path=settings.clip_model_path, verbose=settings.verbose + ) elif settings.chat_format == "hf-autotokenizer": assert ( settings.hf_pretrained_model_name_or_path is not None