feat: Generic Chat Formats, Tool Calling, and Huggingface Pull Support for Multimodal Models (Obsidian, LLaVA1.6, Moondream) (#1147)

* Test dummy image tags in chat templates

* Format and improve  types for llava_cpp.py

* Add from_pretrained support to llava chat format.

* Refactor llava chat format to use a jinja2

* Revert chat format test

* Add moondream support (wip)

* Update moondream chat format

* Update moondream chat format

* Update moondream prompt

* Add function calling support

* Cache last image embed

* Add Llava1.6 support

* Add nanollava support

* Add obisidian support

* Remove unnecessary import

* Re-order multimodal chat formats

* Logits all no longer required for multi-modal models

* Update README.md

* Update docs

* Update README

* Fix typo

* Update README

* Fix typo
This commit is contained in:
Andrei 2024-04-30 01:35:38 -04:00 committed by GitHub
parent 97fb860eba
commit fe2da09538
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 712 additions and 146 deletions

View file

@ -490,14 +490,15 @@ Due to discrepancies between llama.cpp and HuggingFace's tokenizers, it is requi
### Multi-modal Models ### Multi-modal Models
`llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to `llama-cpp-python` supports such as llava1.5 which allow the language model to read information from both text and images.
read information from both text and images.
You'll first need to download one of the available multi-modal models in GGUF format: You'll first need to download one of the available multi-modal models in GGUF format:
- [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b) - [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
- [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b) - [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
- [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1) - [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1)
- [llava-v1.6-34b](https://huggingface.co/cjpais/llava-v1.6-34B-gguf)
- [moondream2](https://huggingface.co/vikhyatk/moondream2)
Then you'll need to use a custom chat handler to load the clip model and process the chat messages and images. Then you'll need to use a custom chat handler to load the clip model and process the chat messages and images.
@ -509,7 +510,6 @@ Then you'll need to use a custom chat handler to load the clip model and process
model_path="./path/to/llava/llama-model.gguf", model_path="./path/to/llava/llama-model.gguf",
chat_handler=chat_handler, chat_handler=chat_handler,
n_ctx=2048, # n_ctx should be increased to accomodate the image embedding n_ctx=2048, # n_ctx should be increased to accomodate the image embedding
logits_all=True,# needed to make llava work
) )
>>> llm.create_chat_completion( >>> llm.create_chat_completion(
messages = [ messages = [
@ -517,14 +517,45 @@ Then you'll need to use a custom chat handler to load the clip model and process
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "image_url", "image_url": {"url": "https://.../image.png"}}, {"type" : "text", "text": "What's in this image?"},
{"type" : "text", "text": "Describe this image in detail please."} {"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" } }
] ]
} }
] ]
) )
``` ```
You can also pull the model from the Hugging Face Hub using the `from_pretrained` method.
```python
>>> from llama_cpp import Llama
>>> from llama_cpp.llama_chat_format import MoondreamChatHandler
>>> chat_handler = MoondreamChatHandler.from_pretrained(
repo_id="vikhyatk/moondream2",
filename="*mmproj*",
)
>>> llm = Llama.from_pretrained(
repo_id="vikhyatk/moondream2"
filename="*text-model*",
chat_handler=chat_handler,
n_ctx=2048, # n_ctx should be increased to accomodate the image embedding
)
>>> llm.create_chat_completion(
messages = [
{
"role": "user",
"content": [
{"type" : "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" } }
]
}
]
)
```
**Note**: Multi-modal models also support tool calling and JSON mode.
<details> <details>
<summary>Loading a Local Image</summary> <summary>Loading a Local Image</summary>

View file

@ -98,6 +98,8 @@ You'll first need to download one of the available multi-modal models in GGUF fo
- [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b) - [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
- [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b) - [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
- [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1) - [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1)
- [llava-v1.6-34b](https://huggingface.co/cjpais/llava-v1.6-34B-gguf)
- [moondream2](https://huggingface.co/vikhyatk/moondream2)
Then when you run the server you'll need to also specify the path to the clip model used for image embedding and the `llava-1-5` chat_format Then when you run the server you'll need to also specify the path to the clip model used for image embedding and the `llava-1-5` chat_format

View file

@ -6,6 +6,8 @@ import ctypes
import dataclasses import dataclasses
import random import random
import string import string
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
import jinja2 import jinja2
@ -2163,42 +2165,80 @@ def functionary_v1_v2_chat_handler(
class Llava15ChatHandler: class Llava15ChatHandler:
_clip_free = None DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
CHAT_FORMAT = (
"{% for message in messages %}"
"{% if message.role == 'system' %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.role == 'user' %}"
"{% if message.content is string %}"
"\nUSER: {{ message.content }}"
"{% elif message.content is iterable %}"
"\nUSER: "
"{% for content in message.content %}"
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% if content.type == 'image_url' and content.image_url is string %}"
"{{ content.image_url }}"
"{% endif %}"
"{% if content.type == 'image_url' and content.image_url is mapping %}"
"{{ content.image_url.url }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"{% endif %}"
"{% if message.role == 'assistant' and message.content is not none %}"
"\nASSISTANT: {{ message.content }}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"\nASSISTANT: "
"{% endif %}"
)
def __init__(self, clip_model_path: str, verbose: bool = False): def __init__(self, clip_model_path: str, verbose: bool = False):
import llama_cpp.llava_cpp as llava_cpp import llama_cpp.llava_cpp as llava_cpp
self._llava_cpp = llava_cpp
self.clip_model_path = clip_model_path self.clip_model_path = clip_model_path
self.verbose = verbose self.verbose = verbose
self._clip_free = self._llava_cpp._libllava.clip_free # type: ignore
self._llava_cpp = llava_cpp # TODO: Fix
self._exit_stack = ExitStack()
self._last_image_embed: Optional[llava_cpp.CtypesPointer[llava_cpp.llava_image_embed]] = None
self._last_image_hash: Optional[int] = None
if not os.path.exists(clip_model_path): if not os.path.exists(clip_model_path):
raise ValueError(f"Clip model path does not exist: {clip_model_path}") raise ValueError(f"Clip model path does not exist: {clip_model_path}")
with suppress_stdout_stderr(disable=self.verbose): with suppress_stdout_stderr(disable=self.verbose):
self.clip_ctx = self._llava_cpp.clip_model_load( clip_ctx = self._llava_cpp.clip_model_load(
self.clip_model_path.encode(), 0 self.clip_model_path.encode(), 0
) )
def __del__(self): if clip_ctx is None:
with suppress_stdout_stderr(disable=self.verbose): raise ValueError(f"Failed to load clip model: {clip_model_path}")
if self.clip_ctx is not None and self._clip_free is not None:
self._clip_free(self.clip_ctx) self.clip_ctx = clip_ctx
self.clip_ctx = None
def clip_free():
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.clip_free(self.clip_ctx)
self._exit_stack.callback(clip_free)
def last_image_embed_free():
with suppress_stdout_stderr(disable=self.verbose):
if self._last_image_embed is not None:
self._llava_cpp.llava_image_embed_free(self._last_image_embed)
self._last_image_embed = None
self._exit_stack.callback(last_image_embed_free)
def load_image(self, image_url: str) -> bytes: def load_image(self, image_url: str) -> bytes:
if image_url.startswith("data:"): return self._load_image(image_url)
import base64
image_bytes = base64.b64decode(image_url.split(",")[1])
return image_bytes
else:
import urllib.request
with urllib.request.urlopen(image_url) as f:
image_bytes = f.read()
return image_bytes
def __call__( def __call__(
self, self,
@ -2216,6 +2256,7 @@ class Llava15ChatHandler:
typical_p: float = 1.0, typical_p: float = 1.0,
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[ response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat llama_types.ChatCompletionRequestResponseFormat
] = None, ] = None,
@ -2230,121 +2271,477 @@ class Llava15ChatHandler:
model: Optional[str] = None, model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore **kwargs, # type: ignore
) -> Union[ ) -> Union[
llama_types.CreateChatCompletionResponse, llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse], Iterator[llama_types.CreateChatCompletionStreamResponse],
]: ]:
assert (
llama.context_params.logits_all is True
) # BUG: logits_all=True is required for llava
assert self.clip_ctx is not None assert self.clip_ctx is not None
system_prompt = _get_system_message(messages)
system_prompt = (
system_prompt
if system_prompt != ""
else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
)
user_role = "\nUSER:"
assistant_role = "\nASSISTANT:"
llama.reset()
llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
for message in messages:
if message["role"] == "user" and message["content"] is not None:
if isinstance(message["content"], str):
llama.eval(
llama.tokenize(
f"{user_role} {message['content']}".encode("utf8"),
add_bos=False,
)
)
else:
assert isinstance(message["content"], list)
llama.eval(
llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
)
for content in message["content"]:
if content["type"] == "text":
llama.eval(
llama.tokenize(
f"{content['text']}".encode("utf8"), add_bos=False
)
)
if content["type"] == "image_url":
image_bytes = (
self.load_image(content["image_url"]["url"])
if isinstance(content["image_url"], dict)
else self.load_image(content["image_url"])
)
import array
data_array = array.array("B", image_bytes) system_prompt = _get_system_message(messages)
c_ubyte_ptr = ( if system_prompt == "":
ctypes.c_ubyte * len(data_array) messages = [llama_types.ChatCompletionRequestSystemMessage(role="system", content=self.DEFAULT_SYSTEM_MESSAGE)] + messages
).from_buffer(data_array)
with suppress_stdout_stderr(disable=self.verbose): image_urls = self.get_image_urls(messages)
embed = ( template = jinja2.Template(self.CHAT_FORMAT)
self._llava_cpp.llava_image_embed_make_with_bytes( text = template.render(messages=messages, add_generation_prompt=True)
self.clip_ctx, split_text = self.split_text_on_image_urls(text, image_urls)
llama.context_params.n_threads,
c_ubyte_ptr, def embed_image_bytes(image_bytes: bytes):
len(image_bytes), if self._last_image_embed is not None and self._last_image_hash is not None and hash(image_bytes) == self._last_image_hash:
) return self._last_image_embed
) with suppress_stdout_stderr(disable=self.verbose):
try: embed = (
n_past = ctypes.c_int(llama.n_tokens) self._llava_cpp.llava_image_embed_make_with_bytes(
n_past_p = ctypes.pointer(n_past) self.clip_ctx,
with suppress_stdout_stderr(disable=self.verbose): llama.context_params.n_threads_batch,
self._llava_cpp.llava_eval_image_embed( (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
llama.ctx, len(image_bytes),
embed,
llama.n_batch,
n_past_p,
)
assert llama.n_ctx() >= n_past.value
llama.n_tokens = n_past.value
finally:
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.llava_image_embed_free(embed)
if message["role"] == "assistant" and message["content"] is not None:
llama.eval(
llama.tokenize(
f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
) )
) )
assert llama.n_ctx() >= llama.n_tokens self._last_image_embed = embed
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False)) self._last_image_hash = hash(image_bytes)
assert llama.n_ctx() >= llama.n_tokens return embed
# Evaluate prompt
llama.reset()
for i, (type_, value) in enumerate(split_text):
if type_ == "text":
tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
if llama.n_tokens + len(tokens) > llama.n_ctx():
raise ValueError("Prompt exceeds n_ctx") # TODO: Fix
llama.eval(tokens)
else:
image_bytes = self.load_image(value)
embed = embed_image_bytes(image_bytes)
if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
raise ValueError("Prompt exceeds n_ctx") # TODO: Fix
n_past = ctypes.c_int(llama.n_tokens)
n_past_p = ctypes.pointer(n_past)
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.llava_eval_image_embed(
llama.ctx,
embed,
llama.n_batch,
n_past_p,
)
llama.n_tokens = n_past.value
# Get prompt tokens to avoid a cache miss
prompt = llama.input_ids[: llama.n_tokens].tolist() prompt = llama.input_ids[: llama.n_tokens].tolist()
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
grammar = _grammar_for_response_format(response_format) grammar = _grammar_for_response_format(response_format)
return _convert_completion_to_chat( # Convert legacy functions to tools
llama.create_completion( if functions is not None:
prompt=prompt, tools = [
temperature=temperature, {
top_p=top_p, "type": "function",
top_k=top_k, "function": function,
min_p=min_p, }
typical_p=typical_p, for function in functions
stream=stream, ]
stop=stop,
max_tokens=max_tokens, # Convert legacy function_call to tool_choice
presence_penalty=presence_penalty, if function_call is not None:
frequency_penalty=frequency_penalty, if isinstance(function_call, str) and (
repeat_penalty=repeat_penalty, function_call == "none" or function_call == "auto"
tfs_z=tfs_z, ):
mirostat_mode=mirostat_mode, tool_choice = function_call
mirostat_tau=mirostat_tau, if isinstance(function_call, dict) and "name" in function_call:
mirostat_eta=mirostat_eta, tool_choice = {
model=model, "type": "function",
logits_processor=logits_processor, "function": {
grammar=grammar, "name": function_call["name"],
), },
}
tool = None
if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None:
name = tool_choice["function"]["name"]
tool = next((t for t in tools if t["function"]["name"] == name), None)
if tool is None:
raise ValueError(f"Tool choice '{name}' not found in tools.")
schema = tool["function"]["parameters"]
try:
# create grammar from json schema
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(schema), verbose=llama.verbose
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF, verbose=llama.verbose
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
logprobs=top_logprobs if logprobs else None,
stream=stream, stream=stream,
stop=stop,
seed=seed,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
logit_bias=logit_bias,
) )
if tool is not None:
tool_name = tool["function"]["name"]
return _convert_completion_to_chat_function(
tool_name, completion_or_chunks, stream
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
@staticmethod
def _load_image(image_url: str) -> bytes:
# TODO: Add Pillow support for other image formats beyond (jpg, png)
if image_url.startswith("data:"):
import base64
image_bytes = base64.b64decode(image_url.split(",")[1])
return image_bytes
else:
import urllib.request
with urllib.request.urlopen(image_url) as f:
image_bytes = f.read()
return image_bytes
@staticmethod
def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
image_urls: List[str] = []
for message in messages:
if message["role"] == "user":
if message["content"] is None:
continue
for content in message["content"]:
if isinstance(content, dict) and "type" in content:
if content["type"] == "image_url":
if (
isinstance(content["image_url"], dict)
and "url" in content["image_url"]
):
image_urls.append(content["image_url"]["url"])
else:
image_urls.append(content["image_url"])
return image_urls
@staticmethod
def split_text_on_image_urls(text: str, image_urls: List[str]):
def find_first(s: str, substrs: List[str]):
for i, substr in enumerate(substrs):
pos = s.find(substr)
if pos != -1:
return pos, i
return None, None
split_text: List[Tuple[Literal["text", "image_url"], str]] = []
remaining = text
while remaining:
# Find first image_url
pos, i = find_first(remaining, image_urls)
if pos is not None and i is not None:
if pos > 0:
split_text.append(("text", remaining[:pos]))
split_text.append(("image_url", image_urls[i]))
remaining = remaining[pos + len(image_urls[i]) :]
else:
split_text.append(("text", remaining))
remaining = ""
return split_text
@classmethod
def from_pretrained(
cls,
repo_id: str,
filename: Optional[str],
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
**kwargs: Any,
) -> "Llava15ChatHandler":
import fnmatch
from pathlib import Path
try:
from huggingface_hub import hf_hub_download, HfFileSystem # type: ignore
from huggingface_hub.utils import validate_repo_id # type: ignore
except ImportError:
raise ImportError(
"Llama.from_pretrained requires the huggingface-hub package. "
"You can install it with `pip install huggingface-hub`."
)
validate_repo_id(repo_id)
hffs = HfFileSystem()
files = [
file["name"] if isinstance(file, dict) else file
for file in hffs.ls(repo_id) # type: ignore
]
# split each file into repo_id, subfolder, filename
file_list: List[str] = []
for file in files:
rel_path = Path(file).relative_to(repo_id)
file_list.append(str(rel_path))
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
if len(matching_files) == 0:
raise ValueError(
f"No file found in {repo_id} that match {filename}\n\n"
f"Available Files:\n{json.dumps(file_list)}"
)
if len(matching_files) > 1:
raise ValueError(
f"Multiple files found in {repo_id} matching {filename}\n\n"
f"Available Files:\n{json.dumps(files)}"
)
(matching_file,) = matching_files
subfolder = str(Path(matching_file).parent)
filename = Path(matching_file).name
# download the file
hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
local_dir=cast(Union[str, Path, None], local_dir),
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cast(Union[str, Path, None], cache_dir),
)
if local_dir is None:
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cast(Union[str, Path, None], cache_dir),
local_files_only=True,
)
else:
model_path = os.path.join(local_dir, filename)
return cls(
clip_model_path=model_path,
**kwargs,
)
class ObsidianChatHandler(Llava15ChatHandler):
# Prompt Format
# The model followed ChatML format. However, with ### as the seperator
# <|im_start|>user
# What is this sign about?\n<image>
# ###
# <|im_start|>assistant
# The sign is about bullying, and it is placed on a black background with a red background.
# ###
CHAT_FORMAT = (
"{% for message in messages %}"
# System message
"{% if message.role == 'system' %}"
"<|im_start|>system\n"
"{{ message.content }}\n"
"###\n"
"{% endif %}"
# User message
"{% if message.role == 'user' %}"
"<|im_start|>user\n"
"{% if message.content is string %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% if content.type == 'image_url' %}"
"{{ content.image_url }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"###\n"
"{% endif %}"
# Assistant message
"{% if message.role == 'assistant' %}"
"<|im_start|>assistant\n"
"{{ message.content }}"
"###\n"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"<|im_start|>assistant\n"
"{% endif %}"
)
class MoondreamChatHandler(Llava15ChatHandler):
# Chat Format:
# f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
CHAT_FORMAT = (
"{% for message in messages %}"
"{% if message.role == 'user' %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
# <image>
"{% if content.type == 'image_url' %}"
"{% if content.image_url is string %}"
"{{ content.image_url }}\n\n"
"{% endif %}"
"{% if content.image_url is mapping %}"
"{{ content.image_url.url }}\n\n"
"{% endif %}"
"{% endif %}"
# Question:
"{% if content.type == 'text' %}"
"Question: {{ content.text }}\n\n"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
# Question:
"{% if message.content is string %}"
"Question: {{ message.content }}\n\n"
"{% endif %}"
"{% endif %}"
# Answer:
"{% if message.role == 'assistant' %}"
"Answer:{{ message.content }}\n\n"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"Answer:"
"{% endif %}"
)
class Llava16ChatHandler(Llava15ChatHandler):
DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. "
# Example prompt
# "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
CHAT_FORMAT = (
"{% for message in messages %}"
"{% if message.role == 'system' %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.role == 'user' %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
# <image>
"{% if content.type == 'image_url' %}"
"{% if content.image_url is string %}"
"{{ content.image_url }}\n"
"{% endif %}"
"{% if content.image_url is mapping %}"
"{{ content.image_url.url }}\n"
"{% endif %}"
"{% endif %}"
# Question:
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
# Question:
"{% if message.content is string %}"
"{{ message.content }}"
"{% endif %}"
"{% endif %}"
# Answer:
"{% if message.role == 'assistant' %}"
"{{ message.content }}"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"Answer:"
"{% endif %}"
)
class NanoLlavaChatHandler(Llava15ChatHandler):
# Prompt Format
# The model follow the ChatML standard, however, without \n at the end of <|im_end|>:
# <|im_start|>system
# Answer the question<|im_end|><|im_start|>user
# <image>
# What is the picture about?<|im_end|><|im_start|>assistant
CHAT_FORMAT = (
"{% for message in messages %}"
# System message
"{% if message.role == 'system' %}"
"<|im_start|>system\n"
"{{ message.content }}"
"<|im_end|>"
"{% endif %}"
# User message
"{% if message.role == 'user' %}"
"<|im_start|>user\n"
"{% if message.content is string %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% if content.type == 'image_url' %}"
"{{ content.image_url }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"<|im_end|>"
"{% endif %}"
# Assistant message
"{% if message.role == 'assistant' %}"
"<|im_start|>assistant\n"
"{{ message.content }}"
"<|im_end|>"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"<|im_start|>assistant\n"
"{% endif %}"
)
@register_chat_completion_handler("chatml-function-calling") @register_chat_completion_handler("chatml-function-calling")

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import sys import sys
import os import os
import ctypes import ctypes
@ -14,10 +16,22 @@ from ctypes import (
Structure, Structure,
) )
import pathlib import pathlib
from typing import List, Union, NewType, Optional, TypeVar, Callable, Any from typing import (
List,
Union,
NewType,
Optional,
TypeVar,
Callable,
Any,
TYPE_CHECKING,
Generic,
)
from typing_extensions import TypeAlias
import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_cpp as llama_cpp
# Load the library # Load the library
def _load_shared_library(lib_base_name: str): def _load_shared_library(lib_base_name: str):
# Construct the paths to the possible shared library names # Construct the paths to the possible shared library names
@ -62,7 +76,7 @@ def _load_shared_library(lib_base_name: str):
for _lib_path in _lib_paths: for _lib_path in _lib_paths:
if _lib_path.exists(): if _lib_path.exists():
try: try:
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
@ -79,8 +93,27 @@ _libllava = _load_shared_library(_libllava_base_name)
# ctypes helper # ctypes helper
if TYPE_CHECKING:
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
class CtypesRef(Generic[CtypesCData]):
pass
CtypesPointerOrRef: TypeAlias = Union[
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
]
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
F = TypeVar("F", bound=Callable[..., Any]) F = TypeVar("F", bound=Callable[..., Any])
def ctypes_function_for_shared_library(lib: ctypes.CDLL): def ctypes_function_for_shared_library(lib: ctypes.CDLL):
def ctypes_function( def ctypes_function(
name: str, argtypes: List[Any], restype: Any, enabled: bool = True name: str, argtypes: List[Any], restype: Any, enabled: bool = True
@ -111,6 +144,7 @@ ctypes_function = ctypes_function_for_shared_library(_libllava)
clip_ctx_p = NewType("clip_ctx_p", int) clip_ctx_p = NewType("clip_ctx_p", int)
clip_ctx_p_ctypes = c_void_p clip_ctx_p_ctypes = c_void_p
# struct llava_image_embed { # struct llava_image_embed {
# float * embed; # float * embed;
# int n_image_pos; # int n_image_pos;
@ -121,36 +155,72 @@ class llava_image_embed(Structure):
("n_image_pos", c_int), ("n_image_pos", c_int),
] ]
# /** sanity check for clip <-> llava embed size match */ # /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip); # LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
@ctypes_function("llava_validate_embed_size", [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], c_bool) @ctypes_function(
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool: "llava_validate_embed_size",
... [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
c_bool,
)
def llava_validate_embed_size(
ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
) -> bool: ...
# /** build an image embed from image file bytes */ # /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
@ctypes_function("llava_image_embed_make_with_bytes", [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], POINTER(llava_image_embed)) @ctypes_function(
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]": "llava_image_embed_make_with_bytes",
... [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
POINTER(llava_image_embed),
)
def llava_image_embed_make_with_bytes(
ctx_clip: clip_ctx_p,
n_threads: Union[c_int, int],
image_bytes: CtypesArray[c_uint8],
image_bytes_length: Union[c_int, int],
/,
) -> "_Pointer[llava_image_embed]": ...
# /** build an image embed from a path to an image filename */ # /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
@ctypes_function("llava_image_embed_make_with_filename", [clip_ctx_p_ctypes, c_int, c_char_p], POINTER(llava_image_embed)) @ctypes_function(
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]": "llava_image_embed_make_with_filename",
... [clip_ctx_p_ctypes, c_int, c_char_p],
POINTER(llava_image_embed),
)
def llava_image_embed_make_with_filename(
ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
) -> "_Pointer[llava_image_embed]": ...
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); # LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */ # /** free an embedding made with llava_image_embed_make_* */
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None) @ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): ...
...
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ # /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); # LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
@ctypes_function("llava_eval_image_embed", [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)], c_bool) @ctypes_function(
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool: "llava_eval_image_embed",
... [
llama_cpp.llama_context_p_ctypes,
POINTER(llava_image_embed),
c_int,
POINTER(c_int),
],
c_bool,
)
def llava_eval_image_embed(
ctx_llama: llama_cpp.llama_context_p,
embed: "_Pointer[llava_image_embed]",
n_batch: Union[c_int, int],
n_past: "_Pointer[c_int]",
/,
) -> bool: ...
################################################ ################################################
@ -161,11 +231,12 @@ def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointe
# /** load mmproj model */ # /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity); # CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes) @ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]: def clip_model_load(
... fname: bytes, verbosity: Union[c_int, int], /
) -> Optional[clip_ctx_p]: ...
# /** free mmproj model */ # /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx); # CLIP_API void clip_free(struct clip_ctx * ctx);
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None) @ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
def clip_free(ctx: clip_ctx_p, /): def clip_free(ctx: clip_ctx_p, /): ...
...

View file

@ -72,9 +72,74 @@ class LlamaProxy:
chat_handler = None chat_handler = None
if settings.chat_format == "llava-1-5": if settings.chat_format == "llava-1-5":
assert settings.clip_model_path is not None, "clip model not found" assert settings.clip_model_path is not None, "clip model not found"
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler( if settings.hf_model_repo_id is not None:
clip_model_path=settings.clip_model_path, verbose=settings.verbose chat_handler = (
) llama_cpp.llama_chat_format.Llava15ChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "obsidian":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.ObsidianChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.ObsidianChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "llava-1-6":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Llava16ChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llava16ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "moondream":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.MoondreamChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.MoondreamChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "nanollava":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.NanoLlavaChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.NanoLlavaChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "hf-autotokenizer": elif settings.chat_format == "hf-autotokenizer":
assert ( assert (
settings.hf_pretrained_model_name_or_path is not None settings.hf_pretrained_model_name_or_path is not None