Multimodal Support (Llava 1.5) (#821)
* llava v1.5 integration
* Point llama.cpp to fork
* Add llava shared library target
* Fix type
* Update llama.cpp
* Add llava api
* Revert changes to llama and llama_cpp
* Update llava example
* Add types for new gpt-4-vision-preview api
* Fix typo
* Update llama.cpp
* Update llama_types to match OpenAI v1 API
* Update ChatCompletionFunction type
* Reorder request parameters
* More API type fixes
* Even More Type Updates
* Add parameter for custom chat_handler to Llama class
* Fix circular import
* Convert to absolute imports
* Fix
* Fix pydantic Jsontype bug
* Accept list of prompt tokens in create_completion
* Add llava1.5 chat handler
* Add Multimodal notebook
* Clean up examples
* Add server docs

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
Parent: 56171cf7bf
Commit: aab74f0b2b
10 changed files with 796 additions and 102 deletions
@@ -41,4 +41,23 @@ if (LLAMA_BUILD)
    FILES $<TARGET_RUNTIME_DLLS:llama>
    DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
  )
  add_subdirectory(vendor/llama.cpp/examples/llava)
  set_target_properties(llava_shared PROPERTIES OUTPUT_NAME "llava")
  install(
    TARGETS llava_shared
    LIBRARY DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
    RUNTIME DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
    ARCHIVE DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
    FRAMEWORK DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
    RESOURCE DESTINATION ${SKBUILD_PLATLIB_DIR}/llama_cpp
  )
  # Temporary fix for https://github.com/scikit-build/scikit-build-core/issues/374
  install(
    TARGETS llava_shared
    LIBRARY DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
    RUNTIME DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
    ARCHIVE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
    FRAMEWORK DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
    RESOURCE DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/llama_cpp
  )
endif()
docs/server.md (new file, 77 lines)
@@ -0,0 +1,77 @@
# OpenAI Compatible Server

`llama-cpp-python` offers an OpenAI API compatible web server.

This web server can be used to serve local models and easily connect them to existing clients.

## Setup

### Installation

The server can be installed by running the following command:

```bash
pip install llama-cpp-python[server]
```

### Running the server

The server can then be started by running the following command:

```bash
python3 -m llama_cpp.server --model <model_path>
```

### Server options

For a full list of options, run:

```bash
python3 -m llama_cpp.server --help
```

NOTE: All server options are also available as environment variables. For example, `--model` can be set by setting the `MODEL` environment variable.
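For instance, the start command above could equally be driven by the environment; a minimal sketch, using the `MODEL` variable documented in the note:

```bash
# Equivalent to `python3 -m llama_cpp.server --model <model_path>`
export MODEL=<model_path>
python3 -m llama_cpp.server
```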
## Guides

### Multi-modal Models

`llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to
read information from both text and images.

You'll first need to download one of the available multi-modal models in GGUF format:

- [llava1.5 7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
- [llava1.5 13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)

Then when you run the server you'll need to also specify the path to the clip model used for image embedding:

```bash
python3 -m llama_cpp.server --model <model_path> --clip-model-path <clip_model_path>
```

Then you can just use the OpenAI API as normal:

```python3
from openai import OpenAI

client = OpenAI(base_url="http://<host>:<port>/v1", api_key="sk-xxx")
response = client.chat.completions.create(
    model="gpt-4-vision-preview",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "<image_url>"
                    },
                },
                {"type": "text", "text": "What does the image say"},
            ],
        }
    ],
)
print(response)
```

examples/notebooks/Multimodal.ipynb (new file, 84 lines)
@@ -0,0 +1,84 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ChatCompletion(id='chatcmpl-65a710ba-41d1-4d0a-a124-a44b2b4a0189', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content=' The image reads \"LlamaC++.\"', role='assistant', function_call=None, tool_calls=None))], created=1699413274, model='gpt-4-vision-preview', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=10, prompt_tokens=624, total_tokens=634))\n"
     ]
    }
   ],
   "source": [
    "from openai import OpenAI\n",
    "\n",
    "import urllib.request\n",
    "import base64\n",
    "\n",
    "def get_data_url(url):\n",
    "    return \"data:image/png;base64,\" + base64.b64encode(urllib.request.urlopen(url).read()).decode(\"utf-8\")\n",
    "\n",
    "client = OpenAI(base_url=\"http://100.64.159.73:8000/v1\", api_key=\"sk-1234\")\n",
    "response = client.chat.completions.create(\n",
    "    model=\"gpt-4-vision-preview\",\n",
    "    messages=[\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [\n",
    "                {\n",
    "                    \"type\": \"image_url\",\n",
    "                    \"image_url\": {\n",
    "                        \"url\": get_data_url(\"https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png\"),\n",
    "                        # \"url\": \"https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png\",\n",
    "                    },\n",
    "                },\n",
    "                {\"type\": \"text\", \"text\": \"What does the image say\"},\n",
    "            ],\n",
    "        }\n",
    "    ],\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5+"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -21,9 +21,9 @@ from collections import deque, OrderedDict
import diskcache
import ctypes

from . import llama_cpp
from .llama_types import *
from .llama_grammar import LlamaGrammar
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format

import numpy as np

@@ -752,6 +752,7 @@ class Llama:
numa: bool = False,
# Chat Format Params
chat_format: str = "llama-2",
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Misc
verbose: bool = True,
# Extra Params

@@ -784,6 +785,7 @@ class Llama:
lora_path: Path to a LoRA file to apply to the model.
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
chat_format: String specifying the chat format to use when calling create_chat_completion.
chat_handler: Optional chat handler to use when calling create_chat_completion.
verbose: Print verbose output to stderr.

Raises:

@@ -910,6 +912,7 @@ class Llama:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

self.chat_format = chat_format
self.chat_handler = chat_handler

self._n_vocab = self.n_vocab()
self._n_ctx = self.n_ctx()

@@ -1231,7 +1234,7 @@ class Llama:
else:
inputs = input

data: List[EmbeddingData] = []
data: List[Embedding] = []
total_tokens = 0
for index, input in enumerate(inputs):
tokens = self.tokenize(input.encode("utf-8"), special=True)

@@ -1276,7 +1279,7 @@ class Llama:

def _create_completion(
self,
prompt: str,
prompt: Union[str, List[int]],
suffix: Optional[str] = None,
max_tokens: int = 16,
temperature: float = 0.8,

@@ -1297,7 +1300,9 @@ class Llama:
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
) -> Union[
Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
]:
assert self._ctx is not None
assert suffix is None or suffix.__class__ is str

@@ -1309,7 +1314,7 @@ class Llama:
self.tokenize(prompt.encode("utf-8"), special=True)
if prompt != ""
else [self.token_bos()]
)
) if isinstance(prompt, str) else prompt
text: bytes = b""
returned_tokens: int = 0
stop = (

@@ -1322,7 +1327,7 @@ class Llama:

if len(prompt_tokens) >= self._n_ctx:
raise ValueError(
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self._ctx)}"
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
)

if max_tokens <= 0:

@@ -1732,7 +1737,7 @@ class Llama:

def create_completion(
self,
prompt: str,
prompt: Union[str, List[int]],
suffix: Optional[str] = None,
max_tokens: int = 128,
temperature: float = 0.8,
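With this change `prompt` accepts a list of token ids as well as a string (the llava chat handler below relies on this). A minimal sketch of the new call form, with a placeholder model path:

```python
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder path

# Pass pre-tokenized input instead of a string, per prompt: Union[str, List[int]].
tokens = llm.tokenize(b"Q: Name the planets in the solar system? A: ")
out = llm.create_completion(prompt=tokens, max_tokens=32)
print(out["choices"][0]["text"])
```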
@@ -1753,7 +1758,7 @@ class Llama:
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
"""Generate text from a prompt.

Args:

@@ -1800,7 +1805,7 @@ class Llama:
grammar=grammar,
)
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks
chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
return chunks
completion: Completion = next(completion_or_chunks)  # type: ignore
return completion

@@ -1828,7 +1833,7 @@ class Llama:
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[Completion, Iterator[CompletionChunk]]:
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
"""Generate text from a prompt.

Args:

@@ -1879,7 +1884,9 @@ class Llama:
self,
messages: List[ChatCompletionRequestMessage],
functions: Optional[List[ChatCompletionFunction]] = None,
function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
function_call: Optional[ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[ChatCompletionTool]] = None,
tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,

@@ -1896,7 +1903,9 @@ class Llama:
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
) -> Union[
CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
]:
"""Generate a chat completion from a list of messages.

Args:

@@ -1912,12 +1921,16 @@ class Llama:
Returns:
Generated chat completion or a stream of chat completion chunks.
"""
handler = llama_chat_format.get_chat_completion_handler(self.chat_format)
handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
self.chat_format
)
return handler(
self,
llama=self,
messages=messages,
functions=functions,
function_call=function_call,
tools=tools,
tool_choice=tool_choice,
temperature=temperature,
top_p=top_p,
top_k=top_k,

@@ -1974,6 +1987,7 @@ class Llama:
numa=self.numa,
# Chat Format Params
chat_format=self.chat_format,
chat_handler=self.chat_handler,
# Misc
verbose=self.verbose,
)

@@ -2015,6 +2029,7 @@ class Llama:
numa=state["numa"],
# Chat Format Params
chat_format=state["chat_format"],
chat_handler=state["chat_handler"],
# Misc
verbose=state["verbose"],
)
@@ -1,22 +1,24 @@
from __future__ import annotations

import os
import ctypes
import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol

from . import llama_types
from . import llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama as llama


class LlamaChatCompletionHandler(Protocol):
def __call__(
self,
*,
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[
Union[str, llama_types.ChatCompletionFunctionCall]
] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,

@@ -33,7 +35,8 @@ class LlamaChatCompletionHandler(Protocol):
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
**kwargs,  # type: ignore
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
...


@@ -199,7 +202,7 @@ def _convert_text_completion_to_chat(


def _convert_text_completion_chunks_to_chat(
chunks: Iterator[llama_types.CompletionChunk],
chunks: Iterator[llama_types.CreateCompletionStreamResponse],
) -> Iterator[llama_types.ChatCompletionChunk]:
for i, chunk in enumerate(chunks):
if i == 0:

@@ -239,12 +242,15 @@ def _convert_text_completion_chunks_to_chat(

def _convert_completion_to_chat(
completion_or_chunks: Union[
llama_types.Completion, Iterator[llama_types.CompletionChunk]
llama_types.CreateCompletionResponse,
Iterator[llama_types.CreateCompletionStreamResponse],
],
stream: bool = False,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
) -> Union[
llama_types.CreateChatCompletionResponse, Iterator[llama_types.ChatCompletionChunk]
]:
if stream:
chunks: Iterator[llama_types.CompletionChunk] = completion_or_chunks  # type: ignore
chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore
return _convert_text_completion_chunks_to_chat(chunks)
else:
completion: llama_types.Completion = completion_or_chunks  # type: ignore

@@ -329,7 +335,9 @@ def get_chat_format(name: str):
)


def hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path: Union[str, os.PathLike[str]]) -> ChatFormatter:
def hf_autotokenizer_to_chat_formatter(
pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> ChatFormatter:
# https://huggingface.co/docs/transformers/main/chat_templating
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json

@@ -538,7 +546,7 @@ def functionary_chat_handler(
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[Union[str, llama_types.ChatCompletionFunctionCall]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,

@@ -555,6 +563,7 @@ def functionary_chat_handler(
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
**kwargs,  # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""

@@ -613,13 +622,13 @@ def functionary_chat_handler(
all_messages: List[llama_types.ChatCompletionRequestMessage] = []
if functions is not None:
all_messages.append(
llama_types.ChatCompletionRequestMessage(
llama_types.ChatCompletionRequestSystemMessage(
role="system", content=generate_schema_from_functions(functions)
)
)

all_messages.append(
llama_types.ChatCompletionRequestMessage(
llama_types.ChatCompletionRequestSystemMessage(
role="system", content=SYSTEM_MESSAGE
)
)

@@ -636,7 +645,9 @@ def functionary_chat_handler(
all_messages.append(message)

all_messages.append(
llama_types.ChatCompletionRequestMessage(role="assistant", content=None)
llama_types.ChatCompletionRequestAssistantMessage(
role="assistant", content=None
)
)

def message_to_str(msg: llama_types.ChatCompletionRequestMessage):

@@ -713,6 +724,10 @@ def functionary_chat_handler(
prompt=new_prompt, stop=["user:", "</s>"], stream=False
)  # type: ignore

assert "usage" in completion
assert isinstance(function_call, str)
assert stream is False  # TODO: support stream mode

return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"],
object="chat.completion",

@@ -734,3 +749,119 @@ def functionary_chat_handler(
],
usage=completion["usage"],
)


class Llava15ChatHandler:
    def __init__(self, clip_model_path: str):
        import llama_cpp.llava_cpp as llava_cpp

        self._llava_cpp = llava_cpp
        self.clip_model_path = clip_model_path

        self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)

    def __del__(self):
        if self.clip_ctx is not None:
            self._llava_cpp.clip_free(self.clip_ctx)
            self.clip_ctx = None

    def load_image(self, image_url: str) -> bytes:
        if image_url.startswith("data:"):
            import base64

            image_bytes = base64.b64decode(image_url.split(",")[1])
            return image_bytes
        else:
            import urllib.request

            with urllib.request.urlopen(image_url) as f:
                image_bytes = f.read()
                return image_bytes

    def __call__(
        self,
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        max_tokens: int = 256,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        **kwargs,  # type: ignore
    ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
        assert llama.context_params.logits_all is True  # BUG: logits_all=True is required for llava
        assert self.clip_ctx is not None
        system_prompt = _get_system_message(messages)
        system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
        user_role = "\nUSER:"
        assistant_role = "\nASSISTANT:"
        llama.reset()
        llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
        for message in messages:
            if message["role"] == "user" and message["content"] is not None:
                if isinstance(message["content"], str):
                    llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
                else:
                    assert isinstance(message["content"], list)
                    llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
                    for content in message["content"]:
                        if content["type"] == "text":
                            llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
                        if content["type"] == "image_url":
                            image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
                            import array
                            data_array = array.array('B', image_bytes)
                            c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
                            embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
                            # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
                            # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
                            try:
                                n_past = ctypes.c_int(llama.n_tokens)
                                n_past_p = ctypes.pointer(n_past)
                                self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
                                assert llama.n_ctx() >= n_past.value
                                llama.n_tokens = n_past.value
                            finally:
                                self._llava_cpp.llava_image_embed_free(embed)
            if message["role"] == "assistant" and message["content"] is not None:
                llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
        llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))

        prompt = llama._input_ids.tolist()

        return _convert_completion_to_chat(llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream=stream,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        ), stream=stream)
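The new `Llava15ChatHandler` plugs into `Llama` through the `chat_handler` parameter added in this commit (the server wires it up the same way when `chat_format` is `llava-1-5`). A sketch with placeholder model paths; note the handler asserts `logits_all=True`:

```python
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

chat_handler = Llava15ChatHandler(clip_model_path="./mmproj-model-f16.gguf")  # placeholder path
llm = Llama(
    model_path="./llava-v1.5-7b.Q4_K.gguf",  # placeholder path
    chat_handler=chat_handler,
    logits_all=True,  # required by the handler's assert above
)

response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "<image_url>"}},
                {"type": "text", "text": "What does the image say"},
            ],
        }
    ],
)
print(response["choices"][0]["message"]["content"])
```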
@@ -19,7 +19,7 @@ from typing import (
overload,
)

from . import llama_cpp
import llama_cpp.llama_cpp as llama_cpp

# Type aliases
llama_grammar_element = llama_cpp.llama_grammar_element
@@ -1,4 +1,6 @@
"""Types and request signatrues for OpenAI compatibility
"""Types and request signatures for OpenAI compatibility

NOTE: These types may change to match the OpenAI OpenAPI specification.

Based on the OpenAI OpenAPI specification:
https://github.com/openai/openai-openapi/blob/master/openapi.yaml

@@ -8,6 +10,12 @@ from typing import Any, List, Optional, Dict, Union
from typing_extensions import TypedDict, NotRequired, Literal


# NOTE: Defining this correctly using annotations seems to break pydantic validation.
# This is a workaround until we can figure out how to do this correctly
# JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
JsonType = Union[None, int, str, bool, List[Any], Dict[str, Any]]


class EmbeddingUsage(TypedDict):
prompt_tokens: int
total_tokens: int

@@ -19,9 +27,6 @@ class Embedding(TypedDict):
embedding: List[float]


EmbeddingData = Embedding


class CreateEmbeddingResponse(TypedDict):
object: Literal["list"]
model: str

@@ -49,110 +54,92 @@ class CompletionUsage(TypedDict):
total_tokens: int


class CreateCompletionStreamResponse(TypedDict):
id: str
object: Literal["text_completion"]
created: int
model: str
choices: List[CompletionChoice]


CompletionChunk = CreateCompletionStreamResponse


class CreateCompletionResponse(TypedDict):
id: str
object: Literal["text_completion"]
created: int
model: str
choices: List[CompletionChoice]
usage: CompletionUsage
usage: NotRequired[CompletionUsage]


Completion = CreateCompletionResponse


class ChatCompletionFunctionCall(TypedDict):
class ChatCompletionResponseFunctionCall(TypedDict):
name: str
arguments: str


class ChatCompletionResponseMessage(TypedDict):
role: Literal["assistant", "user", "system", "function"]
content: Optional[str]
user: NotRequired[str]
function_call: NotRequired[ChatCompletionFunctionCall]
tool_calls: NotRequired["ChatCompletionMessageToolCalls"]
role: Literal["assistant", "function"]  # NOTE: "function" may be incorrect here
function_call: NotRequired[ChatCompletionResponseFunctionCall]  # DEPRECATED


ChatCompletionMessage = ChatCompletionResponseMessage


class ChatCompletionResponseFunction(TypedDict):
class ChatCompletionFunction(TypedDict):
name: str
description: NotRequired[str]
parameters: Dict[str, Any]  # TODO: make this more specific


ChatCompletionFunction = ChatCompletionResponseFunction
parameters: Dict[str, JsonType]  # TODO: make this more specific


class ChatCompletionResponseChoice(TypedDict):
index: int
message: ChatCompletionMessage
message: "ChatCompletionResponseMessage"
finish_reason: Optional[str]


ChatCompletionChoice = ChatCompletionResponseChoice


class CreateChatCompletionResponse(TypedDict):
id: str
object: Literal["chat.completion"]
created: int
model: str
choices: List[ChatCompletionChoice]
choices: List["ChatCompletionResponseChoice"]
usage: CompletionUsage


ChatCompletion = CreateChatCompletionResponse
class ChatCompletionMessageToolCallChunkFunction(TypedDict):
name: str
arguments: str


class ChatCompletionMessageToolCallChunk(TypedDict):
index: int
id: NotRequired[str]
type: Literal["function"]
function: ChatCompletionMessageToolCallChunkFunction


class ChatCompletionStreamResponseDeltaEmpty(TypedDict):
pass


ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty
class ChatCompletionStreamResponseDeltaFunctionCall(TypedDict):
name: str
arguments: str


class ChatCompletionStreamResponseDelta(TypedDict):
role: NotRequired[Literal["assistant"]]
content: NotRequired[str]
function_call: NotRequired[ChatCompletionFunctionCall]


ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta
function_call: NotRequired[
ChatCompletionStreamResponseDeltaFunctionCall
]  # DEPRECATED
tool_calls: NotRequired[List[ChatCompletionMessageToolCallChunk]]
role: NotRequired[Literal["system", "user", "assistant", "tool"]]


class ChatCompletionStreamResponseChoice(TypedDict):
index: int
delta: Union[ChatCompletionChunkDelta, ChatCompletionChunkDeltaEmpty]
delta: Union[
ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
]
finish_reason: Optional[Literal["stop", "length", "function_call"]]


ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice


class ChatCompletionStreamResponse(TypedDict):
class CreateChatCompletionStreamResponse(TypedDict):
id: str
model: str
object: Literal["chat.completion.chunk"]
created: int
choices: List[ChatCompletionChunkChoice]


ChatCompletionChunk = ChatCompletionStreamResponse

JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
choices: List[ChatCompletionStreamResponseChoice]


class ChatCompletionFunctions(TypedDict):

@@ -165,8 +152,137 @@ class ChatCompletionFunctionCallOption(TypedDict):
name: str


class ChatCompletionRequestMessage(TypedDict):
role: Literal["assistant", "user", "system", "function"]
class ChatCompletionRequestMessageContentPartText(TypedDict):
type: Literal["text"]
text: str


class ChatCompletionRequestMessageContentPartImageImageUrl(TypedDict):
url: str
detail: NotRequired[Literal["auto", "low", "high"]]


class ChatCompletionRequestMessageContentPartImage(TypedDict):
type: Literal["image_url"]
image_url: Union[str, ChatCompletionRequestMessageContentPartImageImageUrl]


ChatCompletionRequestMessageContentPart = Union[
ChatCompletionRequestMessageContentPartText,
ChatCompletionRequestMessageContentPartImage,
]


class ChatCompletionRequestSystemMessage(TypedDict):
role: Literal["system"]
content: Optional[str]
name: NotRequired[str]
function_call: NotRequired[ChatCompletionFunctionCall]


class ChatCompletionRequestUserMessage(TypedDict):
role: Literal["user"]
content: Optional[Union[str, List[ChatCompletionRequestMessageContentPart]]]


class ChatCompletionMessageToolCallFunction(TypedDict):
name: str
arguments: str


class ChatCompletionMessageToolCall(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionMessageToolCallFunction


ChatCompletionMessageToolCalls = List[ChatCompletionMessageToolCall]


class ChatCompletionRequestAssistantMessageFunctionCall(TypedDict):
name: str
arguments: str


class ChatCompletionRequestAssistantMessage(TypedDict):
role: Literal["assistant"]
content: Optional[str]
tool_calls: NotRequired[ChatCompletionMessageToolCalls]
function_call: NotRequired[
ChatCompletionRequestAssistantMessageFunctionCall
]  # DEPRECATED


class ChatCompletionRequestToolMessage(TypedDict):
role: Literal["tool"]
content: Optional[str]
tool_call_id: str


class ChatCompletionRequestFunctionMessage(TypedDict):
role: Literal["function"]
content: Optional[str]
name: str


ChatCompletionRequestMessage = Union[
ChatCompletionRequestSystemMessage,
ChatCompletionRequestUserMessage,
ChatCompletionRequestAssistantMessage,
ChatCompletionRequestUserMessage,
ChatCompletionRequestToolMessage,
ChatCompletionRequestFunctionMessage,
]


class ChatCompletionRequestFunctionCallOption(TypedDict):
name: str


ChatCompletionRequestFunctionCall = Union[
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
]

ChatCompletionFunctionParameters = Dict[str, JsonType]  # TODO: make this more specific


class ChatCompletionToolFunction(TypedDict):
name: str
description: NotRequired[str]
parameters: ChatCompletionFunctionParameters


class ChatCompletionTool(TypedDict):
type: Literal["function"]
function: ChatCompletionToolFunction


class ChatCompletionNamedToolChoiceFunction(TypedDict):
name: str


class ChatCompletionNamedToolChoice(TypedDict):
type: Literal["function"]
function: ChatCompletionNamedToolChoiceFunction


ChatCompletionToolChoiceOption = Union[
Literal["none", "auto"], ChatCompletionNamedToolChoice
]


# NOTE: The following type names are not part of the OpenAI OpenAPI specification
# and will be removed in a future major release.

EmbeddingData = Embedding
CompletionChunk = CreateCompletionResponse
Completion = CreateCompletionResponse
CreateCompletionStreamResponse = CreateCompletionResponse
ChatCompletionMessage = ChatCompletionResponseMessage
ChatCompletionChoice = ChatCompletionResponseChoice
ChatCompletion = CreateChatCompletionResponse
ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty
ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice
ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta
ChatCompletionChunk = CreateChatCompletionStreamResponse
ChatCompletionStreamResponse = CreateChatCompletionStreamResponse
ChatCompletionResponseFunction = ChatCompletionFunction
ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall

llama_cpp/llava_cpp.py (new file, 232 lines)
@@ -0,0 +1,232 @@
import sys
import os
import ctypes
from ctypes import (
    c_bool,
    c_char_p,
    c_int,
    c_int8,
    c_int32,
    c_uint8,
    c_uint32,
    c_size_t,
    c_float,
    c_double,
    c_void_p,
    POINTER,
    _Pointer,  # type: ignore
    Structure,
    Array,
)
import pathlib
from typing import List, Union

import llama_cpp.llama_cpp as llama_cpp

# Load the library
def _load_shared_library(lib_base_name: str):
    # Construct the paths to the possible shared library names
    _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
    # Searching for the library in the current directory under the name "libllama" (default name
    # for llamacpp) and "llama" (default name for this repo)
    _lib_paths: List[pathlib.Path] = []
    # Determine the file extension based on the platform
    if sys.platform.startswith("linux"):
        _lib_paths += [
            _base_path / f"lib{lib_base_name}.so",
        ]
    elif sys.platform == "darwin":
        _lib_paths += [
            _base_path / f"lib{lib_base_name}.so",
            _base_path / f"lib{lib_base_name}.dylib",
        ]
    elif sys.platform == "win32":
        _lib_paths += [
            _base_path / f"{lib_base_name}.dll",
            _base_path / f"lib{lib_base_name}.dll",
        ]
    else:
        raise RuntimeError("Unsupported platform")

    if "LLAMA_CPP_LIB" in os.environ:
        lib_base_name = os.environ["LLAMA_CPP_LIB"]
        _lib = pathlib.Path(lib_base_name)
        _base_path = _lib.parent.resolve()
        _lib_paths = [_lib.resolve()]

    cdll_args = dict()  # type: ignore
    # Add the library directory to the DLL search path on Windows (if needed)
    if sys.platform == "win32" and sys.version_info >= (3, 8):
        os.add_dll_directory(str(_base_path))
        if "CUDA_PATH" in os.environ:
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
        cdll_args["winmode"] = ctypes.RTLD_GLOBAL

    # Try to load the shared library, handling potential errors
    for _lib_path in _lib_paths:
        if _lib_path.exists():
            try:
                return ctypes.CDLL(str(_lib_path), **cdll_args)
            except Exception as e:
                raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )


# Specify the base name of the shared library to load
_libllava_base_name = "llava"

# Load the library
_libllava = _load_shared_library(_libllava_base_name)


################################################
# llava.h
################################################

# struct clip_ctx;
clip_ctx_p = c_void_p

# struct llava_image_embed {
#     float * embed;
#     int n_image_pos;
# };
class llava_image_embed(Structure):
    _fields_ = [
        ("embed", POINTER(c_float)),
        ("n_image_pos", c_int),
    ]

# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
def llava_validate_embed_size(ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p) -> bool:
    return _libllava.llava_validate_embed_size(ctx_llama, ctx_clip)

_libllava.llava_validate_embed_size.argtypes = [llama_cpp.llama_context_p, clip_ctx_p]
_libllava.llava_validate_embed_size.restype = c_bool

# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
def llava_image_embed_make_with_bytes(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_bytes: bytes, image_bytes_length: Union[c_int, int]) -> "_Pointer[llava_image_embed]":
    return _libllava.llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length)

_libllava.llava_image_embed_make_with_bytes.argtypes = [clip_ctx_p, c_int, POINTER(c_uint8), c_int]
_libllava.llava_image_embed_make_with_bytes.restype = POINTER(llava_image_embed)

# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
def llava_image_embed_make_with_filename(ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes) -> "_Pointer[llava_image_embed]":
    return _libllava.llava_image_embed_make_with_filename(ctx_clip, n_threads, image_path)

_libllava.llava_image_embed_make_with_filename.argtypes = [clip_ctx_p, c_int, c_char_p]
_libllava.llava_image_embed_make_with_filename.restype = POINTER(llava_image_embed)

# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]"):
    return _libllava.llava_image_embed_free(embed)

_libllava.llava_image_embed_free.argtypes = [POINTER(llava_image_embed)]
_libllava.llava_image_embed_free.restype = None

# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
def llava_eval_image_embed(ctx_llama: llama_cpp.llama_context_p, embed: "_Pointer[llava_image_embed]", n_batch: Union[c_int, int], n_past: "_Pointer[c_int]") -> bool:
    return _libllava.llava_eval_image_embed(ctx_llama, embed, n_batch, n_past)

_libllava.llava_eval_image_embed.argtypes = [llama_cpp.llama_context_p, POINTER(llava_image_embed), c_int, POINTER(c_int)]
_libllava.llava_eval_image_embed.restype = c_bool


################################################
# clip.h
################################################

# struct clip_vision_hparams {
#     int32_t image_size;
#     int32_t patch_size;
#     int32_t hidden_size;
#     int32_t n_intermediate;
#     int32_t projection_dim;
#     int32_t n_head;
#     int32_t n_layer;
#     float eps;
# };
class clip_vision_hparams(Structure):
    _fields_ = [
        ("image_size", c_int32),
        ("patch_size", c_int32),
        ("hidden_size", c_int32),
        ("n_intermediate", c_int32),
        ("projection_dim", c_int32),
        ("n_head", c_int32),
        ("n_layer", c_int32),
        ("eps", c_float),
    ]

# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
def clip_model_load(fname: bytes, verbosity: Union[c_int, int]) -> clip_ctx_p:
    return _libllava.clip_model_load(fname, verbosity)

_libllava.clip_model_load.argtypes = [c_char_p, c_int]
_libllava.clip_model_load.restype = clip_ctx_p

# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
def clip_free(ctx: clip_ctx_p):
    return _libllava.clip_free(ctx)

_libllava.clip_free.argtypes = [clip_ctx_p]
_libllava.clip_free.restype = None

# size_t clip_embd_nbytes(const struct clip_ctx * ctx);
# int clip_n_patches(const struct clip_ctx * ctx);
# int clip_n_mmproj_embd(const struct clip_ctx * ctx);

# // RGB uint8 image
# struct clip_image_u8 {
#     int nx;
#     int ny;
#     uint8_t * data = NULL;
#     size_t size;
# };

# // RGB float32 image (NHWC)
# // Memory layout: RGBRGBRGB...
# struct clip_image_f32 {
#     int nx;
#     int ny;
#     float * data = NULL;
#     size_t size;
# };

# struct clip_image_u8_batch {
#     struct clip_image_u8 * data;
#     size_t size;
# };

# struct clip_image_f32_batch {
#     struct clip_image_f32 * data;
#     size_t size;
# };

# struct clip_image_u8 * make_clip_image_u8();
# struct clip_image_f32 * make_clip_image_f32();
# CLIP_API void clip_image_u8_free(clip_image_u8 * img);
# CLIP_API void clip_image_f32_free(clip_image_f32 * img);
# CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
# /** interpret bytes as an image file with length bytes_length, and use the result to populate img */
# CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);

# bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
# bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);

# bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
#     float * vec);

# bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype);
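These ctypes bindings are what `Llava15ChatHandler` drives internally. The following sketch mirrors that call sequence against a high-level `Llama` object; the model, mmproj, and image paths are placeholders, and in practice the chat handler above does this for you:

```python
import ctypes
from llama_cpp import Llama
import llama_cpp.llava_cpp as llava_cpp

llm = Llama(model_path="./llava-v1.5-7b.Q4_K.gguf", logits_all=True)  # placeholder path
clip_ctx = llava_cpp.clip_model_load(b"./mmproj-model-f16.gguf", 0)   # placeholder path

image_bytes = open("image.png", "rb").read()                          # placeholder image
buf = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)

# Build the CLIP embedding for the image and append it to the llama context,
# following the same sequence as Llava15ChatHandler.__call__ above.
embed = llava_cpp.llava_image_embed_make_with_bytes(clip_ctx, 4, buf, len(image_bytes))
try:
    n_past = ctypes.c_int(llm.n_tokens)
    llava_cpp.llava_eval_image_embed(llm.ctx, embed, llm.n_batch, ctypes.pointer(n_past))
    llm.n_tokens = n_past.value
finally:
    llava_cpp.llava_image_embed_free(embed)
llava_cpp.clip_free(clip_ctx)
```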
@@ -138,6 +138,10 @@ class Settings(BaseSettings):
default="llama-2",
description="Chat format to use.",
)
clip_model_path: Optional[str] = Field(
default=None,
description="Path to a CLIP model to use for multi-modal chat completion.",
)
# Cache Params
cache: bool = Field(
default=False,

@@ -375,6 +379,14 @@ def create_app(settings: Optional[Settings] = None):
)
app.include_router(router)
global llama

##
chat_handler = None
if settings.chat_format == "llava-1-5":
assert settings.clip_model_path is not None
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(clip_model_path=settings.clip_model_path)
##

llama = llama_cpp.Llama(
model_path=settings.model,
# Model Params

@@ -411,6 +423,7 @@ def create_app(settings: Optional[Settings] = None):
numa=settings.numa,
# Chat Format Params
chat_format=settings.chat_format,
chat_handler=chat_handler,
# Misc
verbose=settings.verbose,
)

@@ -580,10 +593,6 @@ class CreateCompletionRequest(BaseModel):
max_tokens: int = max_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
echo: bool = Field(
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",

@@ -610,6 +619,10 @@ class CreateCompletionRequest(BaseModel):
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None

model_config = {
"json_schema_extra": {

@@ -688,7 +701,7 @@ async def create_completion(
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

iterator_or_completion: Union[
llama_cpp.Completion, Iterator[llama_cpp.CompletionChunk]
llama_cpp.CreateCompletionResponse, Iterator[llama_cpp.CreateCompletionStreamResponse]
] = await run_in_threadpool(llama, **kwargs)

if isinstance(iterator_or_completion, Iterator):

@@ -697,7 +710,7 @@ async def create_completion(

# If no exception was raised from first_response, we can assume that
# the iterator is valid and we can use it to stream the response.
def iterator() -> Iterator[llama_cpp.CompletionChunk]:
def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
yield first_response
yield from iterator_or_completion

@@ -748,27 +761,30 @@ class ChatCompletionRequestMessage(BaseModel):
)
content: Optional[str] = Field(default="", description="The content of the message.")

from typing import Any

class CreateChatCompletionRequest(BaseModel):
messages: List[Any] = Field(
messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
default=[], description="A list of messages to generate completions for."
)
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
default=None,
description="A list of functions to apply to the generated completions.",
)
function_call: Optional[Union[Literal["auto", "none"], llama_cpp.ChatCompletionFunctionCallOption]] = Field(
function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
default=None,
description="A function to apply to the generated completions.",
)
tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
default=None,
description="A list of tools to apply to the generated completions.",
)
tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
default=None,
description="A tool to apply to the generated completions.",
)  # TODO: verify
max_tokens: int = max_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
stop: Optional[List[str]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field

@@ -784,6 +800,10 @@ class CreateChatCompletionRequest(BaseModel):
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None

model_config = {
"json_schema_extra": {
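Putting the server pieces above together, multi-modal chat is enabled by selecting the `llava-1-5` chat format and pointing `clip_model_path` at the CLIP/mmproj model. A possible invocation, assuming the CLI flags mirror the `Settings` field names shown above and using placeholder paths:

```bash
# Assumed flag spelling; check `python3 -m llama_cpp.server --help` for the exact options.
python3 -m llama_cpp.server \
  --model ./llava-v1.5-7b.Q4_K.gguf \
  --clip_model_path ./mmproj-model-f16.gguf \
  --chat_format llava-1-5
```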
vendor/llama.cpp (vendored submodule)
@@ -1 +1 @@
Subproject commit 2833a6f63c1b87c7f4ac574bcf7a15a2f3bf3ede
Subproject commit 381efbf480959bb6d1e247a8b0c2328f22e350f8