feat: Add llama-3-vision-alpha chat format

Andrei Betlen 2024-05-02 11:32:18 -04:00
parent 4f01c452b6
commit 31b1d95a6c
2 changed files with 76 additions and 2 deletions

@@ -2165,7 +2165,7 @@ def functionary_v1_v2_chat_handler(
 class Llava15ChatHandler:
-    DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+    DEFAULT_SYSTEM_MESSAGE: Optional[str] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
     CHAT_FORMAT = (
         "{% for message in messages %}"
@@ -2288,7 +2288,7 @@ class Llava15ChatHandler:
         assert self.clip_ctx is not None
         system_prompt = _get_system_message(messages)
-        if system_prompt == "":
+        if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None:
             messages = [llama_types.ChatCompletionRequestSystemMessage(role="system", content=self.DEFAULT_SYSTEM_MESSAGE)] + messages
         image_urls = self.get_image_urls(messages)
@@ -2771,6 +2771,66 @@ class NanoLlavaChatHandler(Llava15ChatHandler):
         "{% endif %}"
     )
+class Llama3VisionAlpha(Llava15ChatHandler):
+    # question = "<image>" + q
+    # prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    DEFAULT_SYSTEM_MESSAGE = None
+    CHAT_FORMAT = (
+        "{% for message in messages %}"
+        "<|start_header_id|>"
+        "{% if message.role == 'user' %}"
+        "user<|end_header_id|>\n\n"
+        "{% if message.content is iterable %}"
+        # <image>
+        "{% for content in message.content %}"
+        "{% if content.type == 'image_url' %}"
+        "{% if content.image_url is string %}"
+        "{{ content.image_url }}"
+        "{% endif %}"
+        "{% if content.image_url is mapping %}"
+        "{{ content.image_url.url }}"
+        "{% endif %}"
+        "{% endif %}"
+        "{% endfor %}"
+        # Question:
+        "{% for content in message.content %}"
+        "{% if content.type == 'text' %}"
+        "{{ content.text }}"
+        "{% endif %}"
+        "{% endfor %}"
+        "{% endif %}"
+        # Question:
+        "{% if message.content is string %}"
+        "{{ message.content }}"
+        "{% endif %}"
+        "{% endif %}"
+        # Answer:
+        "{% if message.role == 'assistant' %}"
+        "assistant<|end_header_id|>\n\n"
+        "{{ message.content }}"
+        "{% endif %}"
+        "<|eot_id|>"
+        "{% endfor %}"
+        # Generation prompt
+        "{% if add_generation_prompt %}"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        "{% endif %}"
+    )
 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
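For reference, a minimal usage sketch of the new handler with the high-level Llama API. Because DEFAULT_SYSTEM_MESSAGE is None, no system message is injected when a request omits one (see the guard added to Llava15ChatHandler above). The .gguf file names and the image URL below are placeholder assumptions, not part of this commit:

# Minimal sketch: Llama3VisionAlpha with the high-level Llama API.
# The .gguf file names and the image URL are placeholder assumptions.
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llama3VisionAlpha

chat_handler = Llama3VisionAlpha(clip_model_path="llama-3-vision-alpha-mmproj-f16.gguf")

llm = Llama(
    model_path="llama-3-8b-instruct.Q4_K_M.gguf",
    chat_handler=chat_handler,
    n_ctx=4096,  # room for the image embedding tokens plus the reply
)

response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        }
    ],
)
print(response["choices"][0]["message"]["content"])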

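The CHAT_FORMAT template can also be rendered on its own to inspect the prompt skeleton it produces. Inside Llava15ChatHandler the image URL that shows up in the rendered text is replaced by CLIP image embeddings before evaluation rather than being fed to the model as text. A rough sketch, assuming jinja2 is installed:

# Rough sketch (assumes jinja2): render the new template directly to see
# the Llama 3 style prompt it builds for a mixed image/text user message.
from jinja2 import Template

from llama_cpp.llama_chat_format import Llama3VisionAlpha

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "Describe the image."},
        ],
    }
]

prompt = Template(Llama3VisionAlpha.CHAT_FORMAT).render(
    messages=messages, add_generation_prompt=True
)
# The result starts with "<|start_header_id|>user<|end_header_id|>", embeds the
# image URL and question text, and ends with the assistant generation header.
print(prompt)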

@@ -140,6 +140,20 @@ class LlamaProxy:
             chat_handler = llama_cpp.llama_chat_format.NanoLlavaChatHandler(
                 clip_model_path=settings.clip_model_path, verbose=settings.verbose
             )
elif settings.chat_format == "llama-3-vision-alpha":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Llama3VisionAlpha.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llama3VisionAlpha(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "hf-autotokenizer": elif settings.chat_format == "hf-autotokenizer":
assert ( assert (
settings.hf_pretrained_model_name_or_path is not None settings.hf_pretrained_model_name_or_path is not None
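Once the server is launched with chat_format set to llama-3-vision-alpha (plus a clip_model_path, and optionally hf_model_repo_id to pull the projector from the Hugging Face Hub), the model is reachable through the usual OpenAI-compatible endpoint. A sketch using the openai client; the base URL, API key, and model alias are assumptions, not values defined by this commit:

# Sketch of querying a llama-cpp-python server that was started with the new
# chat format (e.g. --chat_format llama-3-vision-alpha plus --clip_model_path).
# base_url, api_key, and the model alias below are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-no-key-required")

response = client.chat.completions.create(
    model="llama-3-vision-alpha",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "What is in this image?"},
            ],
        }
    ],
)
print(response.choices[0].message.content)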