feat: Generic Chat Formats, Tool Calling, and Huggingface Pull Support for Multimodal Models (Obsidian, LLaVA1.6, Moondream) (#1147)

* Test dummy image tags in chat templates

* Format and improve types for llava_cpp.py

* Add from_pretrained support to llava chat format.

* Refactor llava chat format to use a jinja2 template

* Revert chat format test

* Add moondream support (wip)

* Update moondream chat format

* Update moondream chat format

* Update moondream prompt

* Add function calling support

* Cache last image embed

* Add Llava1.6 support

* Add nanollava support

* Add obsidian support

* Remove unnecessary import

* Re-order multimodal chat formats

* Logits all no longer required for multi-modal models

* Update README.md

* Update docs

* Update README

* Fix typo

* Update README

* Fix typo
Andrei 2024-04-30 01:35:38 -04:00 committed by GitHub
parent 97fb860eba
commit fe2da09538
5 changed files with 712 additions and 146 deletions

@@ -490,14 +490,15 @@ Due to discrepancies between llama.cpp and HuggingFace's tokenizers, it is requi
### Multi-modal Models
`llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to
read information from both text and images.
`llama-cpp-python` supports multi-modal models such as llava1.5, which allow the language model to read information from both text and images.
You'll first need to download one of the available multi-modal models in GGUF format:
- [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
- [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
- [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1)
- [llava-v1.6-34b](https://huggingface.co/cjpais/llava-v1.6-34B-gguf)
- [moondream2](https://huggingface.co/vikhyatk/moondream2)
Then you'll need to use a custom chat handler to load the clip model and process the chat messages and images.
@@ -509,7 +510,6 @@ Then you'll need to use a custom chat handler to load the clip model and process
model_path="./path/to/llava/llama-model.gguf",
chat_handler=chat_handler,
n_ctx=2048, # n_ctx should be increased to accommodate the image embedding
logits_all=True,# needed to make llava work
)
>>> llm.create_chat_completion(
messages = [
@@ -517,14 +517,45 @@ Then you'll need to use a custom chat handler to load the clip model and process
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "https://.../image.png"}},
{"type" : "text", "text": "Describe this image in detail please."}
{"type" : "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" } }
]
}
]
)
```
You can also pull the model from the Hugging Face Hub using the `from_pretrained` method.
```python
>>> from llama_cpp import Llama
>>> from llama_cpp.llama_chat_format import MoondreamChatHandler
>>> chat_handler = MoondreamChatHandler.from_pretrained(
repo_id="vikhyatk/moondream2",
filename="*mmproj*",
)
>>> llm = Llama.from_pretrained(
repo_id="vikhyatk/moondream2"
filename="*text-model*",
chat_handler=chat_handler,
n_ctx=2048, # n_ctx should be increased to accommodate the image embedding
)
>>> llm.create_chat_completion(
messages = [
{
"role": "user",
"content": [
{"type" : "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" } }
]
}
]
)
```
**Note**: Multi-modal models also support tool calling and JSON mode.
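For example, JSON mode is requested through `response_format`, which the chat handler converts into a JSON grammar before sampling. A minimal sketch reusing the `llm` from the example above (the prompt wording is illustrative, not part of this change):
```python
response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Return a JSON object describing this image."},
                {"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"}},
            ],
        }
    ],
    # Constrains generation to valid JSON via a grammar built from the response format.
    response_format={"type": "json_object"},
)
print(response["choices"][0]["message"]["content"])
```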
<details>
<summary>Loading a Local Image</summary>

@@ -98,6 +98,8 @@ You'll first need to download one of the available multi-modal models in GGUF fo
- [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
- [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
- [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1)
- [llava-v1.6-34b](https://huggingface.co/cjpais/llava-v1.6-34B-gguf)
- [moondream2](https://huggingface.co/vikhyatk/moondream2)
Then, when you run the server, you'll also need to specify the path to the clip model used for image embedding and the `llava-1-5` chat_format

@@ -6,6 +6,8 @@ import ctypes
import dataclasses
import random
import string
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
import jinja2
@@ -2163,42 +2165,80 @@ def functionary_v1_v2_chat_handler(
class Llava15ChatHandler:
_clip_free = None
DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
CHAT_FORMAT = (
"{% for message in messages %}"
"{% if message.role == 'system' %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.role == 'user' %}"
"{% if message.content is string %}"
"\nUSER: {{ message.content }}"
"{% elif message.content is iterable %}"
"\nUSER: "
"{% for content in message.content %}"
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% if content.type == 'image_url' and content.image_url is string %}"
"{{ content.image_url }}"
"{% endif %}"
"{% if content.type == 'image_url' and content.image_url is mapping %}"
"{{ content.image_url.url }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"{% endif %}"
"{% if message.role == 'assistant' and message.content is not none %}"
"\nASSISTANT: {{ message.content }}"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"\nASSISTANT: "
"{% endif %}"
)
def __init__(self, clip_model_path: str, verbose: bool = False):
import llama_cpp.llava_cpp as llava_cpp
self._llava_cpp = llava_cpp
self.clip_model_path = clip_model_path
self.verbose = verbose
self._clip_free = self._llava_cpp._libllava.clip_free # type: ignore
self._llava_cpp = llava_cpp # TODO: Fix
self._exit_stack = ExitStack()
self._last_image_embed: Optional[llava_cpp.CtypesPointer[llava_cpp.llava_image_embed]] = None
self._last_image_hash: Optional[int] = None
if not os.path.exists(clip_model_path):
raise ValueError(f"Clip model path does not exist: {clip_model_path}")
with suppress_stdout_stderr(disable=self.verbose):
self.clip_ctx = self._llava_cpp.clip_model_load(
clip_ctx = self._llava_cpp.clip_model_load(
self.clip_model_path.encode(), 0
)
def __del__(self):
with suppress_stdout_stderr(disable=self.verbose):
if self.clip_ctx is not None and self._clip_free is not None:
self._clip_free(self.clip_ctx)
self.clip_ctx = None
if clip_ctx is None:
raise ValueError(f"Failed to load clip model: {clip_model_path}")
self.clip_ctx = clip_ctx
def clip_free():
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.clip_free(self.clip_ctx)
self._exit_stack.callback(clip_free)
def last_image_embed_free():
with suppress_stdout_stderr(disable=self.verbose):
if self._last_image_embed is not None:
self._llava_cpp.llava_image_embed_free(self._last_image_embed)
self._last_image_embed = None
self._exit_stack.callback(last_image_embed_free)
def load_image(self, image_url: str) -> bytes:
if image_url.startswith("data:"):
import base64
image_bytes = base64.b64decode(image_url.split(",")[1])
return image_bytes
else:
import urllib.request
with urllib.request.urlopen(image_url) as f:
image_bytes = f.read()
return image_bytes
return self._load_image(image_url)
def __call__(
self,
@@ -2216,6 +2256,7 @@ class Llava15ChatHandler:
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
@@ -2230,121 +2271,477 @@ class Llava15ChatHandler:
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
assert (
llama.context_params.logits_all is True
) # BUG: logits_all=True is required for llava
assert self.clip_ctx is not None
system_prompt = _get_system_message(messages)
system_prompt = (
system_prompt
if system_prompt != ""
else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
)
user_role = "\nUSER:"
assistant_role = "\nASSISTANT:"
llama.reset()
llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
for message in messages:
if message["role"] == "user" and message["content"] is not None:
if isinstance(message["content"], str):
llama.eval(
llama.tokenize(
f"{user_role} {message['content']}".encode("utf8"),
add_bos=False,
)
)
else:
assert isinstance(message["content"], list)
llama.eval(
llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
)
for content in message["content"]:
if content["type"] == "text":
llama.eval(
llama.tokenize(
f"{content['text']}".encode("utf8"), add_bos=False
)
)
if content["type"] == "image_url":
image_bytes = (
self.load_image(content["image_url"]["url"])
if isinstance(content["image_url"], dict)
else self.load_image(content["image_url"])
)
import array
data_array = array.array("B", image_bytes)
c_ubyte_ptr = (
ctypes.c_ubyte * len(data_array)
).from_buffer(data_array)
with suppress_stdout_stderr(disable=self.verbose):
embed = (
self._llava_cpp.llava_image_embed_make_with_bytes(
self.clip_ctx,
llama.context_params.n_threads,
c_ubyte_ptr,
len(image_bytes),
)
)
try:
n_past = ctypes.c_int(llama.n_tokens)
n_past_p = ctypes.pointer(n_past)
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.llava_eval_image_embed(
llama.ctx,
embed,
llama.n_batch,
n_past_p,
)
assert llama.n_ctx() >= n_past.value
llama.n_tokens = n_past.value
finally:
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.llava_image_embed_free(embed)
if message["role"] == "assistant" and message["content"] is not None:
llama.eval(
llama.tokenize(
f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
system_prompt = _get_system_message(messages)
if system_prompt == "":
messages = [llama_types.ChatCompletionRequestSystemMessage(role="system", content=self.DEFAULT_SYSTEM_MESSAGE)] + messages
image_urls = self.get_image_urls(messages)
template = jinja2.Template(self.CHAT_FORMAT)
text = template.render(messages=messages, add_generation_prompt=True)
split_text = self.split_text_on_image_urls(text, image_urls)
def embed_image_bytes(image_bytes: bytes):
if self._last_image_embed is not None and self._last_image_hash is not None and hash(image_bytes) == self._last_image_hash:
return self._last_image_embed
with suppress_stdout_stderr(disable=self.verbose):
embed = (
self._llava_cpp.llava_image_embed_make_with_bytes(
self.clip_ctx,
llama.context_params.n_threads_batch,
(ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
len(image_bytes),
)
)
assert llama.n_ctx() >= llama.n_tokens
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
assert llama.n_ctx() >= llama.n_tokens
self._last_image_embed = embed
self._last_image_hash = hash(image_bytes)
return embed
# Evaluate prompt
llama.reset()
for i, (type_, value) in enumerate(split_text):
if type_ == "text":
tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
if llama.n_tokens + len(tokens) > llama.n_ctx():
raise ValueError("Prompt exceeds n_ctx") # TODO: Fix
llama.eval(tokens)
else:
image_bytes = self.load_image(value)
embed = embed_image_bytes(image_bytes)
if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
raise ValueError("Prompt exceeds n_ctx") # TODO: Fix
n_past = ctypes.c_int(llama.n_tokens)
n_past_p = ctypes.pointer(n_past)
with suppress_stdout_stderr(disable=self.verbose):
self._llava_cpp.llava_eval_image_embed(
llama.ctx,
embed,
llama.n_batch,
n_past_p,
)
llama.n_tokens = n_past.value
# Get prompt tokens to avoid a cache miss
prompt = llama.input_ids[: llama.n_tokens].tolist()
if response_format is not None and response_format["type"] == "json_object":
grammar = _grammar_for_response_format(response_format)
return _convert_completion_to_chat(
llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
),
# Convert legacy functions to tools
if functions is not None:
tools = [
{
"type": "function",
"function": function,
}
for function in functions
]
# Convert legacy function_call to tool_choice
if function_call is not None:
if isinstance(function_call, str) and (
function_call == "none" or function_call == "auto"
):
tool_choice = function_call
if isinstance(function_call, dict) and "name" in function_call:
tool_choice = {
"type": "function",
"function": {
"name": function_call["name"],
},
}
tool = None
if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None:
name = tool_choice["function"]["name"]
tool = next((t for t in tools if t["function"]["name"] == name), None)
if tool is None:
raise ValueError(f"Tool choice '{name}' not found in tools.")
schema = tool["function"]["parameters"]
try:
# create grammar from json schema
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(schema), verbose=llama.verbose
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF, verbose=llama.verbose
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
logprobs=top_logprobs if logprobs else None,
stream=stream,
stop=stop,
seed=seed,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
logit_bias=logit_bias,
)
if tool is not None:
tool_name = tool["function"]["name"]
return _convert_completion_to_chat_function(
tool_name, completion_or_chunks, stream
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
@staticmethod
def _load_image(image_url: str) -> bytes:
# TODO: Add Pillow support for other image formats beyond (jpg, png)
if image_url.startswith("data:"):
import base64
image_bytes = base64.b64decode(image_url.split(",")[1])
return image_bytes
else:
import urllib.request
with urllib.request.urlopen(image_url) as f:
image_bytes = f.read()
return image_bytes
@staticmethod
def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
image_urls: List[str] = []
for message in messages:
if message["role"] == "user":
if message["content"] is None:
continue
for content in message["content"]:
if isinstance(content, dict) and "type" in content:
if content["type"] == "image_url":
if (
isinstance(content["image_url"], dict)
and "url" in content["image_url"]
):
image_urls.append(content["image_url"]["url"])
else:
image_urls.append(content["image_url"])
return image_urls
@staticmethod
def split_text_on_image_urls(text: str, image_urls: List[str]):
def find_first(s: str, substrs: List[str]):
for i, substr in enumerate(substrs):
pos = s.find(substr)
if pos != -1:
return pos, i
return None, None
split_text: List[Tuple[Literal["text", "image_url"], str]] = []
remaining = text
while remaining:
# Find first image_url
pos, i = find_first(remaining, image_urls)
if pos is not None and i is not None:
if pos > 0:
split_text.append(("text", remaining[:pos]))
split_text.append(("image_url", image_urls[i]))
remaining = remaining[pos + len(image_urls[i]) :]
else:
split_text.append(("text", remaining))
remaining = ""
return split_text
@classmethod
def from_pretrained(
cls,
repo_id: str,
filename: Optional[str],
local_dir: Optional[Union[str, os.PathLike[str]]] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
**kwargs: Any,
) -> "Llava15ChatHandler":
import fnmatch
from pathlib import Path
try:
from huggingface_hub import hf_hub_download, HfFileSystem # type: ignore
from huggingface_hub.utils import validate_repo_id # type: ignore
except ImportError:
raise ImportError(
"Llama.from_pretrained requires the huggingface-hub package. "
"You can install it with `pip install huggingface-hub`."
)
validate_repo_id(repo_id)
hffs = HfFileSystem()
files = [
file["name"] if isinstance(file, dict) else file
for file in hffs.ls(repo_id) # type: ignore
]
# split each file into repo_id, subfolder, filename
file_list: List[str] = []
for file in files:
rel_path = Path(file).relative_to(repo_id)
file_list.append(str(rel_path))
matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore
if len(matching_files) == 0:
raise ValueError(
f"No file found in {repo_id} that match {filename}\n\n"
f"Available Files:\n{json.dumps(file_list)}"
)
if len(matching_files) > 1:
raise ValueError(
f"Multiple files found in {repo_id} matching {filename}\n\n"
f"Available Files:\n{json.dumps(files)}"
)
(matching_file,) = matching_files
subfolder = str(Path(matching_file).parent)
filename = Path(matching_file).name
# download the file
hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
local_dir=cast(Union[str, Path, None], local_dir),
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cast(Union[str, Path, None], cache_dir),
)
if local_dir is None:
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
cache_dir=cast(Union[str, Path, None], cache_dir),
local_files_only=True,
)
else:
model_path = os.path.join(local_dir, filename)
return cls(
clip_model_path=model_path,
**kwargs,
)
class ObsidianChatHandler(Llava15ChatHandler):
# Prompt Format
# The model follows the ChatML format, but with ### as the separator
# <|im_start|>user
# What is this sign about?\n<image>
# ###
# <|im_start|>assistant
# The sign is about bullying, and it is placed on a black background with a red background.
# ###
CHAT_FORMAT = (
"{% for message in messages %}"
# System message
"{% if message.role == 'system' %}"
"<|im_start|>system\n"
"{{ message.content }}\n"
"###\n"
"{% endif %}"
# User message
"{% if message.role == 'user' %}"
"<|im_start|>user\n"
"{% if message.content is string %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% if content.type == 'image_url' %}"
"{{ content.image_url }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"###\n"
"{% endif %}"
# Assistant message
"{% if message.role == 'assistant' %}"
"<|im_start|>assistant\n"
"{{ message.content }}"
"###\n"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"<|im_start|>assistant\n"
"{% endif %}"
)
class MoondreamChatHandler(Llava15ChatHandler):
# Chat Format:
# f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
CHAT_FORMAT = (
"{% for message in messages %}"
"{% if message.role == 'user' %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
# <image>
"{% if content.type == 'image_url' %}"
"{% if content.image_url is string %}"
"{{ content.image_url }}\n\n"
"{% endif %}"
"{% if content.image_url is mapping %}"
"{{ content.image_url.url }}\n\n"
"{% endif %}"
"{% endif %}"
# Question:
"{% if content.type == 'text' %}"
"Question: {{ content.text }}\n\n"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
# Question:
"{% if message.content is string %}"
"Question: {{ message.content }}\n\n"
"{% endif %}"
"{% endif %}"
# Answer:
"{% if message.role == 'assistant' %}"
"Answer:{{ message.content }}\n\n"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"Answer:"
"{% endif %}"
)
class Llava16ChatHandler(Llava15ChatHandler):
DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. "
# Example prompt
# "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
CHAT_FORMAT = (
"{% for message in messages %}"
"{% if message.role == 'system' %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.role == 'user' %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
# <image>
"{% if content.type == 'image_url' %}"
"{% if content.image_url is string %}"
"{{ content.image_url }}\n"
"{% endif %}"
"{% if content.image_url is mapping %}"
"{{ content.image_url.url }}\n"
"{% endif %}"
"{% endif %}"
# Question:
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
# Question:
"{% if message.content is string %}"
"{{ message.content }}"
"{% endif %}"
"{% endif %}"
# Answer:
"{% if message.role == 'assistant' %}"
"{{ message.content }}"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"Answer:"
"{% endif %}"
)
class NanoLlavaChatHandler(Llava15ChatHandler):
# Prompt Format
# The model follows the ChatML standard, however, without \n at the end of <|im_end|>:
# <|im_start|>system
# Answer the question<|im_end|><|im_start|>user
# <image>
# What is the picture about?<|im_end|><|im_start|>assistant
CHAT_FORMAT = (
"{% for message in messages %}"
# System message
"{% if message.role == 'system' %}"
"<|im_start|>system\n"
"{{ message.content }}"
"<|im_end|>"
"{% endif %}"
# User message
"{% if message.role == 'user' %}"
"<|im_start|>user\n"
"{% if message.content is string %}"
"{{ message.content }}"
"{% endif %}"
"{% if message.content is iterable %}"
"{% for content in message.content %}"
"{% if content.type == 'text' %}"
"{{ content.text }}"
"{% endif %}"
"{% if content.type == 'image_url' %}"
"{{ content.image_url }}"
"{% endif %}"
"{% endfor %}"
"{% endif %}"
"<|im_end|>"
"{% endif %}"
# Assistant message
"{% if message.role == 'assistant' %}"
"<|im_start|>assistant\n"
"{{ message.content }}"
"<|im_end|>"
"{% endif %}"
"{% endfor %}"
# Generation prompt
"{% if add_generation_prompt %}"
"<|im_start|>assistant\n"
"{% endif %}"
)
@register_chat_completion_handler("chatml-function-calling")
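As a rough illustration of how these handlers assemble the multimodal prompt (a sketch for this write-up, not part of the diff): messages are rendered through the class's Jinja2 `CHAT_FORMAT`, then the result is split on the image URLs so each image segment can later be replaced with a CLIP embedding. The image URL below is a placeholder and is never fetched here.
```python
import jinja2
from llama_cpp.llama_chat_format import Llava15ChatHandler

messages = [
    {"role": "system", "content": Llava15ChatHandler.DEFAULT_SYSTEM_MESSAGE},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    },
]

# Render the prompt text the same way Llava15ChatHandler.__call__ does.
text = jinja2.Template(Llava15ChatHandler.CHAT_FORMAT).render(
    messages=messages, add_generation_prompt=True
)

# Collect the image URLs and split the rendered prompt into text / image segments.
image_urls = Llava15ChatHandler.get_image_urls(messages)
for kind, value in Llava15ChatHandler.split_text_on_image_urls(text, image_urls):
    print(kind, repr(value))
# text       "...\nUSER: What's in this image?"
# image_url  "https://example.com/cat.png"
# text       "\nASSISTANT: "
```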

@@ -1,3 +1,5 @@
from __future__ import annotations
import sys
import os
import ctypes
@@ -14,10 +16,22 @@ from ctypes import (
Structure,
)
import pathlib
from typing import List, Union, NewType, Optional, TypeVar, Callable, Any
from typing import (
List,
Union,
NewType,
Optional,
TypeVar,
Callable,
Any,
TYPE_CHECKING,
Generic,
)
from typing_extensions import TypeAlias
import llama_cpp.llama_cpp as llama_cpp
# Load the library
def _load_shared_library(lib_base_name: str):
# Construct the paths to the possible shared library names
@@ -62,7 +76,7 @@ def _load_shared_library(lib_base_name: str):
for _lib_path in _lib_paths:
if _lib_path.exists():
try:
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
@@ -79,8 +93,27 @@ _libllava = _load_shared_library(_libllava_base_name)
# ctypes helper
if TYPE_CHECKING:
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
class CtypesRef(Generic[CtypesCData]):
pass
CtypesPointerOrRef: TypeAlias = Union[
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
]
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
F = TypeVar("F", bound=Callable[..., Any])
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
def ctypes_function(
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
@@ -111,6 +144,7 @@ ctypes_function = ctypes_function_for_shared_library(_libllava)
clip_ctx_p = NewType("clip_ctx_p", int)
clip_ctx_p_ctypes = c_void_p
# struct llava_image_embed {
# float * embed;
# int n_image_pos;
@@ -121,36 +155,72 @@ class llava_image_embed(Structure):
("n_image_pos", c_int),
]
# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
@ctypes_function("llava_validate_embed_size", [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], c_bool)
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
...
@ctypes_function(
"llava_validate_embed_size",
[llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
c_bool,
)
def llava_validate_embed_size(
ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
) -> bool: ...
# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
@ctypes_function("llava_image_embed_make_with_bytes", [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], POINTER(llava_image_embed))
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
...
@ctypes_function(
"llava_image_embed_make_with_bytes",
[clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
POINTER(llava_image_embed),
)
def llava_image_embed_make_with_bytes(
ctx_clip: clip_ctx_p,
n_threads: Union[c_int, int],
image_bytes: CtypesArray[c_uint8],
image_bytes_length: Union[c_int, int],
/,
) -> "_Pointer[llava_image_embed]": ...
# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
@ctypes_function("llava_image_embed_make_with_filename", [clip_ctx_p_ctypes, c_int, c_char_p], POINTER(llava_image_embed))
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
...
@ctypes_function(
"llava_image_embed_make_with_filename",
[clip_ctx_p_ctypes, c_int, c_char_p],
POINTER(llava_image_embed),
)
def llava_image_embed_make_with_filename(
ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
) -> "_Pointer[llava_image_embed]": ...
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
...
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): ...
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
@ctypes_function("llava_eval_image_embed", [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)], c_bool)
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
...
@ctypes_function(
"llava_eval_image_embed",
[
llama_cpp.llama_context_p_ctypes,
POINTER(llava_image_embed),
c_int,
POINTER(c_int),
],
c_bool,
)
def llava_eval_image_embed(
ctx_llama: llama_cpp.llama_context_p,
embed: "_Pointer[llava_image_embed]",
n_batch: Union[c_int, int],
n_past: "_Pointer[c_int]",
/,
) -> bool: ...
################################################
@@ -161,11 +231,12 @@ def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointe
# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
...
def clip_model_load(
fname: bytes, verbosity: Union[c_int, int], /
) -> Optional[clip_ctx_p]: ...
# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
def clip_free(ctx: clip_ctx_p, /):
...
def clip_free(ctx: clip_ctx_p, /): ...
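Taken together, these bindings can be exercised directly. Below is a minimal sketch of the embed lifecycle that the chat handler wraps; the file paths and thread count are placeholders, and writing the embed into a llama context is only indicated in a comment.
```python
import ctypes
import llama_cpp.llava_cpp as llava_cpp

# Placeholder paths: a CLIP/mmproj GGUF file and a JPEG/PNG image.
clip_ctx = llava_cpp.clip_model_load(b"mmproj.gguf", 0)  # verbosity=0
if clip_ctx is None:
    raise RuntimeError("failed to load CLIP model")

try:
    with open("image.jpg", "rb") as f:
        image_bytes = f.read()

    # Copy the raw bytes into a ctypes uint8 array, as the binding expects.
    c_bytes = (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes))
    embed = llava_cpp.llava_image_embed_make_with_bytes(
        clip_ctx, 4, c_bytes, len(image_bytes)  # n_threads=4 is an arbitrary choice
    )
    print("image positions:", embed.contents.n_image_pos)

    # Writing the embed into a context would use
    # llava_eval_image_embed(ctx_llama, embed, n_batch, ctypes.pointer(n_past)),
    # exactly as the chat handler does above.
    llava_cpp.llava_image_embed_free(embed)
finally:
    llava_cpp.clip_free(clip_ctx)
```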

@@ -72,9 +72,74 @@ class LlamaProxy:
chat_handler = None
if settings.chat_format == "llava-1-5":
assert settings.clip_model_path is not None, "clip model not found"
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Llava15ChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "obsidian":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.ObsidianChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.ObsidianChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "llava-1-6":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Llava16ChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llava16ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "moondream":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.MoondreamChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.MoondreamChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "nanollava":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.NanoLlavaChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.NanoLlavaChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "hf-autotokenizer":
assert (
settings.hf_pretrained_model_name_or_path is not None