feat: Generic Chat Formats, Tool Calling, and Huggingface Pull Support for Multimodal Models (Obsidian, LLaVA1.6, Moondream) (#1147)
* Test dummy image tags in chat templates
* Format and improve types for llava_cpp.py
* Add from_pretrained support to llava chat format
* Refactor llava chat format to use a jinja2 template
* Revert chat format test
* Add moondream support (wip)
* Update moondream chat format
* Update moondream chat format
* Update moondream prompt
* Add function calling support
* Cache last image embed
* Add Llava1.6 support
* Add nanollava support
* Add Obsidian support
* Remove unnecessary import
* Re-order multimodal chat formats
* Logits all no longer required for multi-modal models
* Update README.md
* Update docs
* Update README
* Fix typo
* Update README
* Fix typo
This commit is contained in:
parent 97fb860eba
commit fe2da09538
5 changed files with 712 additions and 146 deletions
README.md | 41
@@ -490,14 +490,15 @@ Due to discrepancies between llama.cpp and HuggingFace's tokenizers, it is required
 
 ### Multi-modal Models
 
-`llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to
-read information from both text and images.
+`llama-cpp-python` supports multi-modal models such as llava1.5, which allow the language model to read information from both text and images.
 
 You'll first need to download one of the available multi-modal models in GGUF format:
 
 - [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
 - [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
 - [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1)
+- [llava-v1.6-34b](https://huggingface.co/cjpais/llava-v1.6-34B-gguf)
+- [moondream2](https://huggingface.co/vikhyatk/moondream2)
 
 Then you'll need to use a custom chat handler to load the clip model and process the chat messages and images.
@@ -509,7 +510,6 @@ Then you'll need to use a custom chat handler to load the clip model and process
       model_path="./path/to/llava/llama-model.gguf",
       chat_handler=chat_handler,
       n_ctx=2048, # n_ctx should be increased to accommodate the image embedding
-      logits_all=True,# needed to make llava work
 )
 >>> llm.create_chat_completion(
     messages = [
@@ -517,14 +517,45 @@ Then you'll need to use a custom chat handler to load the clip model and process
         {
             "role": "user",
             "content": [
-                {"type": "image_url", "image_url": {"url": "https://.../image.png"}},
-                {"type" : "text", "text": "Describe this image in detail please."}
+                {"type" : "text", "text": "What's in this image?"},
+                {"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" } }
             ]
         }
     ]
 )
 ```
 
+You can also pull the model from the Hugging Face Hub using the `from_pretrained` method.
+
+```python
+>>> from llama_cpp import Llama
+>>> from llama_cpp.llama_chat_format import MoondreamChatHandler
+>>> chat_handler = MoondreamChatHandler.from_pretrained(
+  repo_id="vikhyatk/moondream2",
+  filename="*mmproj*",
+)
+>>> llm = Llama.from_pretrained(
+  repo_id="vikhyatk/moondream2",
+  filename="*text-model*",
+  chat_handler=chat_handler,
+  n_ctx=2048, # n_ctx should be increased to accommodate the image embedding
+)
+>>> llm.create_chat_completion(
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type" : "text", "text": "What's in this image?"},
+                {"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" } }
+            ]
+        }
+    ]
+)
+```
+
+**Note**: Multi-modal models also support tool calling and JSON mode.
+
 <details>
 <summary>Loading a Local Image</summary>
 
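As an illustration (not part of the committed diff), the JSON mode mentioned in the note above goes through the same `create_chat_completion` call; the sketch below assumes `llm` was constructed as in the README example, and the image URL is a placeholder:

```python
>>> llm.create_chat_completion(
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image as a JSON object."},
                {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},  # placeholder URL
            ]
        }
    ],
    response_format={"type": "json_object"},  # constrain the reply to valid JSON
)
```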
@@ -98,6 +98,8 @@ You'll first need to download one of the available multi-modal models in GGUF format:
 - [llava-v1.5-7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
 - [llava-v1.5-13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
 - [bakllava-1-7b](https://huggingface.co/mys/ggml_bakllava-1)
+- [llava-v1.6-34b](https://huggingface.co/cjpais/llava-v1.6-34B-gguf)
+- [moondream2](https://huggingface.co/vikhyatk/moondream2)
 
 Then when you run the server you'll need to also specify the path to the clip model used for image embedding and the `llava-1-5` chat_format
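As an illustrative aside (not from this diff), the server settings that those flags describe can also be built directly in Python; the `ModelSettings` import path and the file paths below are assumptions:

```python
# Hypothetical sketch: the configuration the server CLI flags map to.
from llama_cpp.server.settings import ModelSettings  # assumed import path

settings = ModelSettings(
    model="./models/llava-v1.6-34b.Q4_K_M.gguf",      # placeholder model path
    clip_model_path="./models/mmproj-model-f16.gguf",  # placeholder clip/mmproj path
    chat_format="llava-1-6",                           # one of the new multi-modal formats
    n_ctx=2048,                                        # leave room for the image embedding
)
```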
@@ -6,6 +6,8 @@ import ctypes
 import dataclasses
 import random
 import string
 
+from contextlib import ExitStack
+
 from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
 
 import jinja2
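The new `ExitStack` import supports the cleanup pattern used below: instead of relying on `__del__`, the handler registers callbacks that free native resources when the stack is closed. A minimal standalone sketch of that pattern (names here are illustrative, not from the diff):

```python
from contextlib import ExitStack


class Resource:
    def __init__(self):
        self._exit_stack = ExitStack()
        self._handle = object()  # stand-in for a native handle

        def free_handle():
            print("freeing handle")
            self._handle = None

        # Callbacks run in reverse registration order when the stack is closed.
        self._exit_stack.callback(free_handle)

    def close(self):
        self._exit_stack.close()


r = Resource()
r.close()  # prints "freeing handle"
```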
@@ -2163,42 +2165,80 @@ def functionary_v1_v2_chat_handler(
 
 
 class Llava15ChatHandler:
-    _clip_free = None
+    DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+
+    CHAT_FORMAT = (
+        "{% for message in messages %}"
+        "{% if message.role == 'system' %}" "{{ message.content }}" "{% endif %}"
+        "{% if message.role == 'user' %}"
+        "{% if message.content is string %}"
+        "\nUSER: {{ message.content }}"
+        "{% elif message.content is iterable %}"
+        "\nUSER: "
+        "{% for content in message.content %}"
+        "{% if content.type == 'text' %}" "{{ content.text }}" "{% endif %}"
+        "{% if content.type == 'image_url' and content.image_url is string %}" "{{ content.image_url }}" "{% endif %}"
+        "{% if content.type == 'image_url' and content.image_url is mapping %}" "{{ content.image_url.url }}" "{% endif %}"
+        "{% endfor %}"
+        "{% endif %}"
+        "{% endif %}"
+        "{% if message.role == 'assistant' and message.content is not none %}"
+        "\nASSISTANT: {{ message.content }}"
+        "{% endif %}"
+        "{% endfor %}"
+        "{% if add_generation_prompt %}" "\nASSISTANT: " "{% endif %}"
+    )
 
     def __init__(self, clip_model_path: str, verbose: bool = False):
         import llama_cpp.llava_cpp as llava_cpp
 
-        self._llava_cpp = llava_cpp
         self.clip_model_path = clip_model_path
         self.verbose = verbose
-        self._clip_free = self._llava_cpp._libllava.clip_free  # type: ignore
+
+        self._llava_cpp = llava_cpp  # TODO: Fix
+        self._exit_stack = ExitStack()
+        self._last_image_embed: Optional[llava_cpp.CtypesPointer[llava_cpp.llava_image_embed]] = None
+        self._last_image_hash: Optional[int] = None
 
         if not os.path.exists(clip_model_path):
             raise ValueError(f"Clip model path does not exist: {clip_model_path}")
 
         with suppress_stdout_stderr(disable=self.verbose):
-            self.clip_ctx = self._llava_cpp.clip_model_load(
+            clip_ctx = self._llava_cpp.clip_model_load(
                 self.clip_model_path.encode(), 0
             )
 
-    def __del__(self):
-        with suppress_stdout_stderr(disable=self.verbose):
-            if self.clip_ctx is not None and self._clip_free is not None:
-                self._clip_free(self.clip_ctx)
-                self.clip_ctx = None
+        if clip_ctx is None:
+            raise ValueError(f"Failed to load clip model: {clip_model_path}")
+
+        self.clip_ctx = clip_ctx
+
+        def clip_free():
+            with suppress_stdout_stderr(disable=self.verbose):
+                self._llava_cpp.clip_free(self.clip_ctx)
+
+        self._exit_stack.callback(clip_free)
+
+        def last_image_embed_free():
+            with suppress_stdout_stderr(disable=self.verbose):
+                if self._last_image_embed is not None:
+                    self._llava_cpp.llava_image_embed_free(self._last_image_embed)
+                    self._last_image_embed = None
+
+        self._exit_stack.callback(last_image_embed_free)
 
     def load_image(self, image_url: str) -> bytes:
-        if image_url.startswith("data:"):
-            import base64
-
-            image_bytes = base64.b64decode(image_url.split(",")[1])
-            return image_bytes
-        else:
-            import urllib.request
-
-            with urllib.request.urlopen(image_url) as f:
-                image_bytes = f.read()
-                return image_bytes
+        return self._load_image(image_url)
 
     def __call__(
         self,
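As an aside between hunks: the `CHAT_FORMAT` above is a Jinja2 template that deliberately leaves image URLs inline in the rendered text so they can be located and replaced by embeddings later. A small self-contained sketch of that idea (the template here is a toy, not the `CHAT_FORMAT` above):

```python
import jinja2

TOY_FORMAT = (
    "{% for m in messages %}"
    "{% if m.role == 'user' %}USER: {% for c in m.content %}"
    "{% if c.type == 'text' %}{{ c.text }}{% endif %}"
    "{% if c.type == 'image_url' %}{{ c.image_url.url }}{% endif %}"
    "{% endfor %}{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}\nASSISTANT: {% endif %}"
)

messages = [
    {"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        {"type": "text", "text": "What's in this image?"},
    ]}
]

text = jinja2.Template(TOY_FORMAT).render(messages=messages, add_generation_prompt=True)
print(text)  # the image URL stays inline, ready to be split out and embedded
```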
@@ -2216,6 +2256,7 @@ class Llava15ChatHandler:
         typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
         response_format: Optional[
             llama_types.ChatCompletionRequestResponseFormat
         ] = None,
@@ -2230,68 +2271,54 @@ class Llava15ChatHandler:
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
         Iterator[llama_types.CreateChatCompletionStreamResponse],
     ]:
-        assert (
-            llama.context_params.logits_all is True
-        )  # BUG: logits_all=True is required for llava
         assert self.clip_ctx is not None
-        system_prompt = _get_system_message(messages)
-        system_prompt = (
-            system_prompt
-            if system_prompt != ""
-            else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
-        )
-        user_role = "\nUSER:"
-        assistant_role = "\nASSISTANT:"
-        llama.reset()
-        llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
-        for message in messages:
-            if message["role"] == "user" and message["content"] is not None:
-                if isinstance(message["content"], str):
-                    llama.eval(
-                        llama.tokenize(
-                            f"{user_role} {message['content']}".encode("utf8"),
-                            add_bos=False,
-                        )
-                    )
-                else:
-                    assert isinstance(message["content"], list)
-                    llama.eval(
-                        llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
-                    )
-                    for content in message["content"]:
-                        if content["type"] == "text":
-                            llama.eval(
-                                llama.tokenize(
-                                    f"{content['text']}".encode("utf8"), add_bos=False
-                                )
-                            )
-                        if content["type"] == "image_url":
-                            image_bytes = (
-                                self.load_image(content["image_url"]["url"])
-                                if isinstance(content["image_url"], dict)
-                                else self.load_image(content["image_url"])
-                            )
-                            import array
-
-                            data_array = array.array("B", image_bytes)
-                            c_ubyte_ptr = (
-                                ctypes.c_ubyte * len(data_array)
-                            ).from_buffer(data_array)
+        system_prompt = _get_system_message(messages)
+        if system_prompt == "":
+            messages = [llama_types.ChatCompletionRequestSystemMessage(role="system", content=self.DEFAULT_SYSTEM_MESSAGE)] + messages
+
+        image_urls = self.get_image_urls(messages)
+        template = jinja2.Template(self.CHAT_FORMAT)
+        text = template.render(messages=messages, add_generation_prompt=True)
+        split_text = self.split_text_on_image_urls(text, image_urls)
+
+        def embed_image_bytes(image_bytes: bytes):
+            if self._last_image_embed is not None and self._last_image_hash is not None and hash(image_bytes) == self._last_image_hash:
+                return self._last_image_embed
             with suppress_stdout_stderr(disable=self.verbose):
                 embed = (
                     self._llava_cpp.llava_image_embed_make_with_bytes(
                         self.clip_ctx,
-                        llama.context_params.n_threads,
-                        c_ubyte_ptr,
+                        llama.context_params.n_threads_batch,
+                        (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
                         len(image_bytes),
                     )
                 )
-            try:
+            self._last_image_embed = embed
+            self._last_image_hash = hash(image_bytes)
+            return embed
+
+        # Evaluate prompt
+        llama.reset()
+        for i, (type_, value) in enumerate(split_text):
+            if type_ == "text":
+                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
+                if llama.n_tokens + len(tokens) > llama.n_ctx():
+                    raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
+                llama.eval(tokens)
+            else:
+                image_bytes = self.load_image(value)
+                embed = embed_image_bytes(image_bytes)
+                if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
+                    raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                 n_past = ctypes.c_int(llama.n_tokens)
                 n_past_p = ctypes.pointer(n_past)
                 with suppress_stdout_stderr(disable=self.verbose):
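The `embed_image_bytes` helper above caches the most recent image embedding keyed by `hash(image_bytes)`, so re-sending the same image across turns skips the CLIP forward pass. A standalone sketch of that single-entry cache (the expensive call is simulated):

```python
from typing import Optional, Tuple

_last: Optional[Tuple[int, str]] = None  # (hash of input, cached result)


def expensive_embed(image_bytes: bytes) -> str:
    return f"embedding-of-{len(image_bytes)}-bytes"  # stand-in for the CLIP call


def embed_cached(image_bytes: bytes) -> str:
    global _last
    key = hash(image_bytes)
    if _last is not None and _last[0] == key:
        return _last[1]            # cache hit: same image as last time
    result = expensive_embed(image_bytes)
    _last = (key, result)          # keep only the most recent embedding
    return result


print(embed_cached(b"abc"))
print(embed_cached(b"abc"))  # second call reuses the cached value
```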
@@ -2301,36 +2328,66 @@ class Llava15ChatHandler:
                         llama.n_batch,
                         n_past_p,
                     )
-                assert llama.n_ctx() >= n_past.value
                 llama.n_tokens = n_past.value
-            finally:
-                with suppress_stdout_stderr(disable=self.verbose):
-                    self._llava_cpp.llava_image_embed_free(embed)
-            if message["role"] == "assistant" and message["content"] is not None:
-                llama.eval(
-                    llama.tokenize(
-                        f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
-                    )
-                )
-                assert llama.n_ctx() >= llama.n_tokens
-        llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
-        assert llama.n_ctx() >= llama.n_tokens
 
+        # Get prompt tokens to avoid a cache miss
         prompt = llama.input_ids[: llama.n_tokens].tolist()
 
         if response_format is not None and response_format["type"] == "json_object":
             grammar = _grammar_for_response_format(response_format)
 
-        return _convert_completion_to_chat(
-            llama.create_completion(
+        # Convert legacy functions to tools
+        if functions is not None:
+            tools = [
+                {
+                    "type": "function",
+                    "function": function,
+                }
+                for function in functions
+            ]
+
+        # Convert legacy function_call to tool_choice
+        if function_call is not None:
+            if isinstance(function_call, str) and (
+                function_call == "none" or function_call == "auto"
+            ):
+                tool_choice = function_call
+            if isinstance(function_call, dict) and "name" in function_call:
+                tool_choice = {
+                    "type": "function",
+                    "function": {
+                        "name": function_call["name"],
+                    },
+                }
+
+        tool = None
+        if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None:
+            name = tool_choice["function"]["name"]
+            tool = next((t for t in tools if t["function"]["name"] == name), None)
+            if tool is None:
+                raise ValueError(f"Tool choice '{name}' not found in tools.")
+            schema = tool["function"]["parameters"]
+            try:
+                # create grammar from json schema
+                grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                    json.dumps(schema), verbose=llama.verbose
+                )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF, verbose=llama.verbose
+                )
+
+        completion_or_chunks = llama.create_completion(
             prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
             stop=stop,
+            seed=seed,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
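The block above folds the legacy `functions`/`function_call` parameters into the newer `tools`/`tool_choice` form before a grammar is chosen. A standalone sketch of that conversion with plain dictionaries (illustrative data only):

```python
functions = [{"name": "get_weather", "parameters": {"type": "object", "properties": {}}}]
function_call = {"name": "get_weather"}

tools = None
tool_choice = None

if functions is not None:
    # Each legacy function becomes a {"type": "function", "function": ...} tool.
    tools = [{"type": "function", "function": f} for f in functions]

if function_call is not None:
    if isinstance(function_call, str) and function_call in ("none", "auto"):
        tool_choice = function_call
    elif isinstance(function_call, dict) and "name" in function_call:
        tool_choice = {"type": "function", "function": {"name": function_call["name"]}}

print(tools)
print(tool_choice)  # {'type': 'function', 'function': {'name': 'get_weather'}}
```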
@@ -2342,8 +2399,348 @@ class Llava15ChatHandler:
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
-            ),
-            stream=stream,
+            logit_bias=logit_bias,
         )
+
+        if tool is not None:
+            tool_name = tool["function"]["name"]
+            return _convert_completion_to_chat_function(
+                tool_name, completion_or_chunks, stream
+            )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
+
+    @staticmethod
+    def _load_image(image_url: str) -> bytes:
+        # TODO: Add Pillow support for other image formats beyond (jpg, png)
+        if image_url.startswith("data:"):
+            import base64
+
+            image_bytes = base64.b64decode(image_url.split(",")[1])
+            return image_bytes
+        else:
+            import urllib.request
+
+            with urllib.request.urlopen(image_url) as f:
+                image_bytes = f.read()
+                return image_bytes
+
+    @staticmethod
+    def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
+        image_urls: List[str] = []
+        for message in messages:
+            if message["role"] == "user":
+                if message["content"] is None:
+                    continue
+                for content in message["content"]:
+                    if isinstance(content, dict) and "type" in content:
+                        if content["type"] == "image_url":
+                            if (
+                                isinstance(content["image_url"], dict)
+                                and "url" in content["image_url"]
+                            ):
+                                image_urls.append(content["image_url"]["url"])
+                            else:
+                                image_urls.append(content["image_url"])
+        return image_urls
+
+    @staticmethod
+    def split_text_on_image_urls(text: str, image_urls: List[str]):
+        def find_first(s: str, substrs: List[str]):
+            for i, substr in enumerate(substrs):
+                pos = s.find(substr)
+                if pos != -1:
+                    return pos, i
+            return None, None
+
+        split_text: List[Tuple[Literal["text", "image_url"], str]] = []
+        remaining = text
+        while remaining:
+            # Find first image_url
+            pos, i = find_first(remaining, image_urls)
+            if pos is not None and i is not None:
+                if pos > 0:
+                    split_text.append(("text", remaining[:pos]))
+                split_text.append(("image_url", image_urls[i]))
+                remaining = remaining[pos + len(image_urls[i]) :]
+            else:
+                split_text.append(("text", remaining))
+                remaining = ""
+        return split_text
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        repo_id: str,
+        filename: Optional[str],
+        local_dir: Optional[Union[str, os.PathLike[str]]] = None,
+        local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
+        cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
+        **kwargs: Any,
+    ) -> "Llava15ChatHandler":
+        import fnmatch
+        from pathlib import Path
+        try:
+            from huggingface_hub import hf_hub_download, HfFileSystem  # type: ignore
+            from huggingface_hub.utils import validate_repo_id  # type: ignore
+        except ImportError:
+            raise ImportError(
+                "Llama.from_pretrained requires the huggingface-hub package. "
+                "You can install it with `pip install huggingface-hub`."
+            )
+
+        validate_repo_id(repo_id)
+
+        hffs = HfFileSystem()
+
+        files = [
+            file["name"] if isinstance(file, dict) else file
+            for file in hffs.ls(repo_id)  # type: ignore
+        ]
+
+        # split each file into repo_id, subfolder, filename
+        file_list: List[str] = []
+        for file in files:
+            rel_path = Path(file).relative_to(repo_id)
+            file_list.append(str(rel_path))
+
+        matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore
+
+        if len(matching_files) == 0:
+            raise ValueError(
+                f"No file found in {repo_id} that match {filename}\n\n"
+                f"Available Files:\n{json.dumps(file_list)}"
+            )
+
+        if len(matching_files) > 1:
+            raise ValueError(
+                f"Multiple files found in {repo_id} matching {filename}\n\n"
+                f"Available Files:\n{json.dumps(files)}"
+            )
+
+        (matching_file,) = matching_files
+
+        subfolder = str(Path(matching_file).parent)
+        filename = Path(matching_file).name
+
+        # download the file
+        hf_hub_download(
+            repo_id=repo_id,
+            filename=filename,
+            subfolder=subfolder,
+            local_dir=cast(Union[str, Path, None], local_dir),
+            local_dir_use_symlinks=local_dir_use_symlinks,
+            cache_dir=cast(Union[str, Path, None], cache_dir),
+        )
+
+        if local_dir is None:
+            model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                subfolder=subfolder,
+                local_dir=local_dir,
+                local_dir_use_symlinks=local_dir_use_symlinks,
+                cache_dir=cast(Union[str, Path, None], cache_dir),
+                local_files_only=True,
+            )
+        else:
+            model_path = os.path.join(local_dir, filename)
+
+        return cls(
+            clip_model_path=model_path,
+            **kwargs,
+        )
+
+
+class ObsidianChatHandler(Llava15ChatHandler):
+    # Prompt Format
+    # The model follows the ChatML format, however, with ### as the separator:
+    #
+    # <|im_start|>user
+    # What is this sign about?\n<image>
+    # ###
+    # <|im_start|>assistant
+    # The sign is about bullying, and it is placed on a black background with a red background.
+    # ###
+
+    CHAT_FORMAT = (
+        "{% for message in messages %}"
+        # System message
+        "{% if message.role == 'system' %}"
+        "<|im_start|>system\n" "{{ message.content }}\n" "###\n"
+        "{% endif %}"
+        # User message
+        "{% if message.role == 'user' %}"
+        "<|im_start|>user\n"
+        "{% if message.content is string %}" "{{ message.content }}" "{% endif %}"
+        "{% if message.content is iterable %}"
+        "{% for content in message.content %}"
+        "{% if content.type == 'text' %}" "{{ content.text }}" "{% endif %}"
+        "{% if content.type == 'image_url' %}" "{{ content.image_url }}" "{% endif %}"
+        "{% endfor %}"
+        "{% endif %}"
+        "###\n"
+        "{% endif %}"
+        # Assistant message
+        "{% if message.role == 'assistant' %}"
+        "<|im_start|>assistant\n" "{{ message.content }}" "###\n"
+        "{% endif %}"
+        "{% endfor %}"
+        # Generation prompt
+        "{% if add_generation_prompt %}" "<|im_start|>assistant\n" "{% endif %}"
+    )
+
+
+class MoondreamChatHandler(Llava15ChatHandler):
+    # Chat Format:
+    # f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
+    CHAT_FORMAT = (
+        "{% for message in messages %}"
+        "{% if message.role == 'user' %}"
+        "{% if message.content is iterable %}"
+        "{% for content in message.content %}"
+        # <image>
+        "{% if content.type == 'image_url' %}"
+        "{% if content.image_url is string %}" "{{ content.image_url }}\n\n" "{% endif %}"
+        "{% if content.image_url is mapping %}" "{{ content.image_url.url }}\n\n" "{% endif %}"
+        "{% endif %}"
+        # Question:
+        "{% if content.type == 'text' %}" "Question: {{ content.text }}\n\n" "{% endif %}"
+        "{% endfor %}"
+        "{% endif %}"
+        # Question:
+        "{% if message.content is string %}" "Question: {{ message.content }}\n\n" "{% endif %}"
+        "{% endif %}"
+        # Answer:
+        "{% if message.role == 'assistant' %}" "Answer:{{ message.content }}\n\n" "{% endif %}"
+        "{% endfor %}"
+        # Generation prompt
+        "{% if add_generation_prompt %}" "Answer:" "{% endif %}"
+    )
+
+
+class Llava16ChatHandler(Llava15ChatHandler):
+    DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. "
+
+    # Example prompt
+    # "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
+
+    CHAT_FORMAT = (
+        "{% for message in messages %}"
+        "{% if message.role == 'system' %}" "{{ message.content }}" "{% endif %}"
+        "{% if message.role == 'user' %}"
+        "{% if message.content is iterable %}"
+        "{% for content in message.content %}"
+        # <image>
+        "{% if content.type == 'image_url' %}"
+        "{% if content.image_url is string %}" "{{ content.image_url }}\n" "{% endif %}"
+        "{% if content.image_url is mapping %}" "{{ content.image_url.url }}\n" "{% endif %}"
+        "{% endif %}"
+        # Question:
+        "{% if content.type == 'text' %}" "{{ content.text }}" "{% endif %}"
+        "{% endfor %}"
+        "{% endif %}"
+        # Question:
+        "{% if message.content is string %}" "{{ message.content }}" "{% endif %}"
+        "{% endif %}"
+        # Answer:
+        "{% if message.role == 'assistant' %}" "{{ message.content }}" "{% endif %}"
+        "{% endfor %}"
+        # Generation prompt
+        "{% if add_generation_prompt %}" "Answer:" "{% endif %}"
+    )
+
+
+class NanoLlavaChatHandler(Llava15ChatHandler):
+    # Prompt Format
+    # The model follows the ChatML standard, however, without \n at the end of <|im_end|>:
+    #
+    # <|im_start|>system
+    # Answer the question<|im_end|><|im_start|>user
+    # <image>
+    # What is the picture about?<|im_end|><|im_start|>assistant
+
+    CHAT_FORMAT = (
+        "{% for message in messages %}"
+        # System message
+        "{% if message.role == 'system' %}"
+        "<|im_start|>system\n" "{{ message.content }}" "<|im_end|>"
+        "{% endif %}"
+        # User message
+        "{% if message.role == 'user' %}"
+        "<|im_start|>user\n"
+        "{% if message.content is string %}" "{{ message.content }}" "{% endif %}"
+        "{% if message.content is iterable %}"
+        "{% for content in message.content %}"
+        "{% if content.type == 'text' %}" "{{ content.text }}" "{% endif %}"
+        "{% if content.type == 'image_url' %}" "{{ content.image_url }}" "{% endif %}"
+        "{% endfor %}"
+        "{% endif %}"
+        "<|im_end|>"
+        "{% endif %}"
+        # Assistant message
+        "{% if message.role == 'assistant' %}"
+        "<|im_start|>assistant\n" "{{ message.content }}" "<|im_end|>"
+        "{% endif %}"
+        "{% endfor %}"
+        # Generation prompt
+        "{% if add_generation_prompt %}" "<|im_start|>assistant\n" "{% endif %}"
     )
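Since `split_text_on_image_urls` added above is a plain static method, its behaviour can be shown in isolation; the sample prompt and URL below are illustrative only:

```python
from llama_cpp.llama_chat_format import Llava15ChatHandler

text = "\nUSER: https://example.com/cat.png What's in this image?\nASSISTANT: "
image_urls = ["https://example.com/cat.png"]

# The rendered prompt is cut into ("text", ...) and ("image_url", ...) chunks;
# text chunks are tokenized, image chunks are replaced by CLIP embeddings.
for kind, value in Llava15ChatHandler.split_text_on_image_urls(text, image_urls):
    print(kind, repr(value))
```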
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import sys
 import os
 import ctypes
@@ -14,10 +16,22 @@ from ctypes import (
     Structure,
 )
 import pathlib
-from typing import List, Union, NewType, Optional, TypeVar, Callable, Any
+from typing import (
+    List,
+    Union,
+    NewType,
+    Optional,
+    TypeVar,
+    Callable,
+    Any,
+    TYPE_CHECKING,
+    Generic,
+)
+from typing_extensions import TypeAlias
 
 import llama_cpp.llama_cpp as llama_cpp
 
 
 # Load the library
 def _load_shared_library(lib_base_name: str):
     # Construct the paths to the possible shared library names
@@ -79,8 +93,27 @@ _libllava = _load_shared_library(_libllava_base_name)
 
 # ctypes helper
 
+if TYPE_CHECKING:
+    CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData)  # type: ignore
+
+    CtypesArray: TypeAlias = ctypes.Array[CtypesCData]  # type: ignore
+
+    CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData]  # type: ignore
+
+    CtypesVoidPointer: TypeAlias = ctypes.c_void_p
+
+    class CtypesRef(Generic[CtypesCData]):
+        pass
+
+    CtypesPointerOrRef: TypeAlias = Union[
+        CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
+    ]
+
+    CtypesFuncPointer: TypeAlias = ctypes._FuncPointer  # type: ignore
+
 F = TypeVar("F", bound=Callable[..., Any])
 
 
 def ctypes_function_for_shared_library(lib: ctypes.CDLL):
     def ctypes_function(
         name: str, argtypes: List[Any], restype: Any, enabled: bool = True
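These aliases exist only for type checkers (they sit under `TYPE_CHECKING`); at runtime they correspond to ordinary `ctypes` objects. A minimal sketch of the runtime values they describe (independent of llava):

```python
import ctypes

# A CtypesArray[c_uint8]-shaped value: a fixed-size ctypes array.
data = (ctypes.c_uint8 * 4)(1, 2, 3, 4)

# A CtypesPointer[c_int]-shaped value: a pointer to a ctypes int.
n = ctypes.c_int(7)
p = ctypes.pointer(n)

print(list(data), p.contents.value)  # [1, 2, 3, 4] 7
```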
@@ -111,6 +144,7 @@ ctypes_function = ctypes_function_for_shared_library(_libllava)
 clip_ctx_p = NewType("clip_ctx_p", int)
 clip_ctx_p_ctypes = c_void_p
 
+
 # struct llava_image_embed {
 #     float * embed;
 #     int n_image_pos;
@@ -121,36 +155,72 @@ class llava_image_embed(Structure):
         ("n_image_pos", c_int),
     ]
 
 
 # /** sanity check for clip <-> llava embed size match */
 # LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
-@ctypes_function("llava_validate_embed_size", [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes], c_bool)
-def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /) -> bool:
-    ...
+@ctypes_function(
+    "llava_validate_embed_size",
+    [llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
+    c_bool,
+)
+def llava_validate_embed_size(
+    ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
+) -> bool: ...
 
 
 # /** build an image embed from image file bytes */
 # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
-@ctypes_function("llava_image_embed_make_with_bytes", [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int], POINTER(llava_image_embed))
-def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int], /) -> "_Pointer[llava_image_embed]":
-    ...
+@ctypes_function(
+    "llava_image_embed_make_with_bytes",
+    [clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
+    POINTER(llava_image_embed),
+)
+def llava_image_embed_make_with_bytes(
+    ctx_clip: clip_ctx_p,
+    n_threads: Union[c_int, int],
+    image_bytes: CtypesArray[c_uint8],
+    image_bytes_length: Union[c_int, int],
+    /,
+) -> "_Pointer[llava_image_embed]": ...
 
 
 # /** build an image embed from a path to an image filename */
 # LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
-@ctypes_function("llava_image_embed_make_with_filename", [clip_ctx_p_ctypes, c_int, c_char_p], POINTER(llava_image_embed))
-def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /) -> "_Pointer[llava_image_embed]":
-    ...
+@ctypes_function(
+    "llava_image_embed_make_with_filename",
+    [clip_ctx_p_ctypes, c_int, c_char_p],
+    POINTER(llava_image_embed),
+)
+def llava_image_embed_make_with_filename(
+    ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
+) -> "_Pointer[llava_image_embed]": ...
 
 
 # LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
 # /** free an embedding made with llava_image_embed_make_* */
 @ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
-def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
-    ...
+def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): ...
 
 
 # /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
 # LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
-@ctypes_function("llava_eval_image_embed", [llama_cpp.llama_context_p_ctypes, POINTER(llava_image_embed), c_int, POINTER(c_int)], c_bool)
-def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]", /) -> bool:
-    ...
+@ctypes_function(
+    "llava_eval_image_embed",
+    [
+        llama_cpp.llama_context_p_ctypes,
+        POINTER(llava_image_embed),
+        c_int,
+        POINTER(c_int),
+    ],
+    c_bool,
+)
+def llava_eval_image_embed(
+    ctx_llama: llama_cpp.llama_context_p,
+    embed: "_Pointer[llava_image_embed]",
+    n_batch: Union[c_int, int],
+    n_past: "_Pointer[c_int]",
+    /,
+) -> bool: ...
 
 
 ################################################
@@ -161,11 +231,12 @@ def llava_eval_image_embed
 # /** load mmproj model */
 # CLIP_API struct clip_ctx * clip_model_load    (const char * fname, int verbosity);
 @ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
-def clip_model_load(fname: bytes, verbosity: Union[c_int, int], /) -> Optional[clip_ctx_p]:
-    ...
+def clip_model_load(
+    fname: bytes, verbosity: Union[c_int, int], /
+) -> Optional[clip_ctx_p]: ...
 
 
 # /** free mmproj model */
 # CLIP_API void clip_free(struct clip_ctx * ctx);
 @ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
-def clip_free(ctx: clip_ctx_p, /):
-    ...
+def clip_free(ctx: clip_ctx_p, /): ...
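Putting the bindings in this file together, a hedged sketch of the low-level lifecycle the chat handler drives; paths are placeholders and the thread count is arbitrary:

```python
import llama_cpp.llava_cpp as llava_cpp

# Load the mmproj/CLIP model (placeholder path), embed an image from disk, then free both.
clip_ctx = llava_cpp.clip_model_load(b"./models/mmproj-model-f16.gguf", 0)
if clip_ctx is None:
    raise RuntimeError("failed to load clip model")

embed = llava_cpp.llava_image_embed_make_with_filename(clip_ctx, 4, b"./image.png")
print(embed.contents.n_image_pos)  # positions the image will occupy in the llama context

llava_cpp.llava_image_embed_free(embed)
llava_cpp.clip_free(clip_ctx)
```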
@@ -72,9 +72,74 @@ class LlamaProxy:
         chat_handler = None
         if settings.chat_format == "llava-1-5":
             assert settings.clip_model_path is not None, "clip model not found"
+            if settings.hf_model_repo_id is not None:
+                chat_handler = (
+                    llama_cpp.llama_chat_format.Llava15ChatHandler.from_pretrained(
+                        repo_id=settings.hf_model_repo_id,
+                        filename=settings.clip_model_path,
+                        verbose=settings.verbose,
+                    )
+                )
+            else:
                 chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
                     clip_model_path=settings.clip_model_path, verbose=settings.verbose
                 )
+        elif settings.chat_format == "obsidian":
+            assert settings.clip_model_path is not None, "clip model not found"
+            if settings.hf_model_repo_id is not None:
+                chat_handler = (
+                    llama_cpp.llama_chat_format.ObsidianChatHandler.from_pretrained(
+                        repo_id=settings.hf_model_repo_id,
+                        filename=settings.clip_model_path,
+                        verbose=settings.verbose,
+                    )
+                )
+            else:
+                chat_handler = llama_cpp.llama_chat_format.ObsidianChatHandler(
+                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
+                )
+        elif settings.chat_format == "llava-1-6":
+            assert settings.clip_model_path is not None, "clip model not found"
+            if settings.hf_model_repo_id is not None:
+                chat_handler = (
+                    llama_cpp.llama_chat_format.Llava16ChatHandler.from_pretrained(
+                        repo_id=settings.hf_model_repo_id,
+                        filename=settings.clip_model_path,
+                        verbose=settings.verbose,
+                    )
+                )
+            else:
+                chat_handler = llama_cpp.llama_chat_format.Llava16ChatHandler(
+                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
+                )
+        elif settings.chat_format == "moondream":
+            assert settings.clip_model_path is not None, "clip model not found"
+            if settings.hf_model_repo_id is not None:
+                chat_handler = (
+                    llama_cpp.llama_chat_format.MoondreamChatHandler.from_pretrained(
+                        repo_id=settings.hf_model_repo_id,
+                        filename=settings.clip_model_path,
+                        verbose=settings.verbose,
+                    )
+                )
+            else:
+                chat_handler = llama_cpp.llama_chat_format.MoondreamChatHandler(
+                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
+                )
+        elif settings.chat_format == "nanollava":
+            assert settings.clip_model_path is not None, "clip model not found"
+            if settings.hf_model_repo_id is not None:
+                chat_handler = (
+                    llama_cpp.llama_chat_format.NanoLlavaChatHandler.from_pretrained(
+                        repo_id=settings.hf_model_repo_id,
+                        filename=settings.clip_model_path,
+                        verbose=settings.verbose,
+                    )
+                )
+            else:
+                chat_handler = llama_cpp.llama_chat_format.NanoLlavaChatHandler(
+                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
+                )
         elif settings.chat_format == "hf-autotokenizer":
             assert (
                 settings.hf_pretrained_model_name_or_path is not None
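The five branches above differ only in the handler class, so (purely as an illustration, not code from this change) the same dispatch could be captured with a small lookup table; the helper name is hypothetical:

```python
import llama_cpp.llama_chat_format as llama_chat_format

# Hypothetical refactor sketch: map chat_format names to handler classes.
MULTIMODAL_HANDLERS = {
    "llava-1-5": llama_chat_format.Llava15ChatHandler,
    "obsidian": llama_chat_format.ObsidianChatHandler,
    "llava-1-6": llama_chat_format.Llava16ChatHandler,
    "moondream": llama_chat_format.MoondreamChatHandler,
    "nanollava": llama_chat_format.NanoLlavaChatHandler,
}


def make_chat_handler(settings):
    cls = MULTIMODAL_HANDLERS[settings.chat_format]
    if settings.hf_model_repo_id is not None:
        return cls.from_pretrained(
            repo_id=settings.hf_model_repo_id,
            filename=settings.clip_model_path,
            verbose=settings.verbose,
        )
    return cls(clip_model_path=settings.clip_model_path, verbose=settings.verbose)
```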