baalajimaestro 2024-01-31 21:27:17 +05:30
commit cd66f3cfb4
Signed by: baalajimaestro
GPG key ID: F93C394FE9BBAFD5
13 changed files with 322 additions and 54 deletions

CHANGELOG.md

@@ -7,6 +7,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.37]
- feat: Update llama.cpp to ggerganov/llama.cpp@fea4fd4ba7f6b754ac795387b275e1a014a77bde
- feat: Automatically set chat format from gguf by @abetlen in #1110
## [0.2.36]
- feat: Update llama.cpp to ggerganov/llama.cpp@2aed77eb06a329f0d82bb1c467f4244904d4073f
- feat: Add mistral instruct chat format as "mistral-instruct" by @Rafaelblsilva in #799
## [0.2.35]
- feat: Update llama.cpp to ggerganov/llama.cpp@d2f650cb5b04ee2726663e79b47da5efe196ce00
## [0.2.34]
- feat: Update llama.cpp to ggerganov/llama.cpp@6db2b41a76ee78d5efdd5c3cddd5d7ad3f646855
- feat: Add json schema mode by @abetlen in #1122
## [0.2.33]
- feat: Update llama.cpp to ggerganov/llama.cpp@faa3526a1eba458120987ed8269e5616385a76f4

Makefile

@@ -27,6 +27,15 @@ build.blis:
build.metal:
	CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install --verbose -e .
build.vulkan:
	CMAKE_ARGS="-DLLAMA_VULKAN=on" python3 -m pip install --verbose -e .
build.kompute:
	CMAKE_ARGS="-DLLAMA_KOMPUTE=on" python3 -m pip install --verbose -e .
build.sycl:
	CMAKE_ARGS="-DLLAMA_SYCL=on" python3 -m pip install --verbose -e .
build.sdist:
	python3 -m build --sdist

README.md

@@ -12,20 +12,17 @@ This package provides:
- Low-level access to C API via `ctypes` interface.
- High-level Python API for text completion
  - OpenAI-like API
  - [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp)
  - [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html)
- OpenAI compatible web server
  - [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion)
  - [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling)
  - [Vision API support](https://llama-cpp-python.readthedocs.io/en/latest/server/#multimodal-models)
  - [Multiple Models](https://llama-cpp-python.readthedocs.io/en/latest/server/#configuration-and-multi-model-support)
Documentation is available at [https://llama-cpp-python.readthedocs.io/en/latest](https://llama-cpp-python.readthedocs.io/en/latest).
## Installation
`llama-cpp-python` can be installed directly from PyPI as a source distribution by running:
@@ -38,7 +35,6 @@ This will build `llama.cpp` from source using cmake and your system's c compiler
If you run into issues during installation add the `--verbose` flag to the `pip install` command to see the full cmake build log.
### Installation with Specific Hardware Acceleration (BLAS, CUDA, Metal, etc)
The default pip install behaviour is to build `llama.cpp` for CPU only on Linux and Windows and use Metal on MacOS.
@@ -71,7 +67,7 @@ CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp-
#### cuBLAS
-To install with cuBLAS, set the `LLAMA_CUBLAS=1` environment variable before installing:
+To install with cuBLAS, set the `LLAMA_CUBLAS=on` environment variable before installing:
```bash
CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python
@@ -87,7 +83,7 @@ CMAKE_ARGS="-DLLAMA_METAL=on" pip install llama-cpp-python
#### CLBlast
-To install with CLBlast, set the `LLAMA_CLBLAST=1` environment variable before installing:
+To install with CLBlast, set the `LLAMA_CLBLAST=on` environment variable before installing:
```bash
CMAKE_ARGS="-DLLAMA_CLBLAST=on" pip install llama-cpp-python
@@ -101,13 +97,37 @@ To install with hipBLAS / ROCm support for AMD cards, set the `LLAMA_HIPBLAS=on`
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" pip install llama-cpp-python
```
#### Vulkan
To install with Vulkan support, set the `LLAMA_VULKAN=on` environment variable before installing:
```bash
CMAKE_ARGS="-DLLAMA_VULKAN=on" pip install llama-cpp-python
```
#### Kompute
To install with Kompute support, set the `LLAMA_KOMPUTE=on` environment variable before installing:
```bash
CMAKE_ARGS="-DLLAMA_KOMPUTE=on" pip install llama-cpp-python
```
#### SYCL
To install with SYCL support, set the `LLAMA_SYCL=on` environment variable before installing:
```bash
CMAKE_ARGS="-DLLAMA_SYCL=on" pip install llama-cpp-python
```
### Windows Notes
If you run into issues where it complains it can't find `'nmake'` `'?'` or CMAKE_C_COMPILER, you can extract w64devkit as [mentioned in llama.cpp repo](https://github.com/ggerganov/llama.cpp#openblas) and add those manually to CMAKE_ARGS before running `pip` install:
```ps
$env:CMAKE_GENERATOR = "MinGW Makefiles"
$env:CMAKE_ARGS = "-DLLAMA_OPENBLAS=on -DCMAKE_C_COMPILER=C:/w64devkit/bin/gcc.exe -DCMAKE_CXX_COMPILER=C:/w64devkit/bin/g++.exe"
```
See the above instructions and set `CMAKE_ARGS` to the BLAS backend you want to use.
@@ -157,7 +177,7 @@ Below is a short example demonstrating how to use the high-level API to for basi
>>> from llama_cpp import Llama
>>> llm = Llama(
      model_path="./models/7B/llama-model.gguf",
      # n_gpu_layers=-1, # Uncomment to use GPU acceleration
      # seed=1337, # Uncomment to set a specific seed
      # n_ctx=2048, # Uncomment to increase the context window
)
@@ -216,6 +236,59 @@ Note that `chat_format` option must be set for the particular model you are usin
Chat completion is available through the [`create_chat_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion) method of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class.
### JSON and JSON Schema Mode
If you want to constrain chat responses to only valid JSON or a specific JSON Schema you can use the `response_format` argument to the `create_chat_completion` method.
#### JSON Mode
The following example will constrain the response to be valid JSON.
```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="path/to/model.gguf", chat_format="chatml")
>>> llm.create_chat_completion(
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant that outputs in JSON.",
        },
        {"role": "user", "content": "Who won the world series in 2020"},
    ],
    response_format={
        "type": "json_object",
    },
    temperature=0.7,
)
```
#### JSON Schema Mode
To constrain the response to a specific JSON Schema, you can use the `schema` property of the `response_format` argument.
```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="path/to/model.gguf", chat_format="chatml")
>>> llm.create_chat_completion(
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant that outputs in JSON.",
        },
        {"role": "user", "content": "Who won the world series in 2020"},
    ],
    response_format={
        "type": "json_object",
        "schema": {
            "type": "object",
            "properties": {"team_name": {"type": "string"}},
            "required": ["team_name"],
        },
    },
    temperature=0.7,
)
```
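The constrained output is still returned as a string, so it has to be parsed client-side. A minimal sketch, assuming the OpenAI-style response dict that `create_chat_completion` returns and a placeholder model path:
```python
import json
from llama_cpp import Llama

llm = Llama(model_path="path/to/model.gguf", chat_format="chatml")  # placeholder path
result = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Who won the world series in 2020"}],
    response_format={
        "type": "json_object",
        "schema": {
            "type": "object",
            "properties": {"team_name": {"type": "string"}},
            "required": ["team_name"],
        },
    },
)
# The grammar-constrained content is plain text that should now be valid JSON.
data = json.loads(result["choices"][0]["message"]["content"])
print(data["team_name"])
```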
### Function Calling
The high-level API also provides a simple interface for function calling.
@@ -223,7 +296,6 @@ The high-level API also provides a simple interface for function calling.
Note that the only model that supports full function calling at this time is "functionary".
The gguf-converted files for this model can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
```python
>>> from llama_cpp import Llama
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary")
@@ -232,7 +304,7 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
        {
            "role": "system",
            "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"
        },
        {
            "role": "user",
@@ -271,7 +343,6 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
### Multi-modal Models
`llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to
read information from both text and images.
@@ -317,7 +388,6 @@ For instance, if you want to work with larger contexts, you can expand the conte
llm = Llama(model_path="./models/7B/llama-model.gguf", n_ctx=2048)
```
## OpenAI Compatible Web Server
`llama-cpp-python` offers a web server which aims to act as a drop-in replacement for the OpenAI API.
@@ -365,7 +435,8 @@ A Docker image is available on [GHCR](https://ghcr.io/abetlen/llama-cpp-python).
```bash
docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/llama-model.gguf ghcr.io/abetlen/llama-cpp-python:latest
```
[Docker on termux (requires root)](https://gist.github.com/FreddieOliveira/efe850df7ff3951cb62d74bd770dce27) is currently the only known way to run this on phones, see [termux support issue](https://github.com/abetlen/llama-cpp-python/issues/389)
## Low-level API
@@ -393,7 +464,6 @@ Below is a short example demonstrating how to use the low-level API to tokenize
Check out the [examples folder](examples/low_level_api) for more examples of using the low-level API.
## Documentation
Documentation is available via [https://llama-cpp-python.readthedocs.io/](https://llama-cpp-python.readthedocs.io/).


@@ -9,7 +9,7 @@ export MODEL=../models/7B/...
Then run:
```
-uvicorn llama_cpp.server.app:app --reload
+uvicorn --factory llama_cpp.server.app:create_app --reload
```
or

llama_cpp/__init__.py

@@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
-__version__ = "0.2.33"
+__version__ = "0.2.37"

llama_cpp/_internals.py

@@ -216,13 +216,13 @@ class _LlamaModel:
        for i in range(llama_cpp.llama_model_meta_count(self.model)):
            nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
            if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
            key = buffer.value.decode("utf-8")
            nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
            if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
            value = buffer.value.decode("utf-8")
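The `+ 1` leaves room for the terminating NUL when the buffer is regrown. A small self-contained sketch of the same grow-and-retry pattern; the `fill` callback and `fake_fill` stand-in are hypothetical, not part of the library:
```python
import ctypes

def read_c_string(fill, initial_size: int = 32) -> str:
    """Call a snprintf-style function `fill(buf, size) -> wanted_bytes`,
    growing the buffer and retrying once if the first buffer was too small."""
    buffer_size = initial_size
    buffer = ctypes.create_string_buffer(buffer_size)
    nbytes = fill(buffer, buffer_size)
    if nbytes > buffer_size:
        buffer_size = nbytes + 1  # extra byte for the NUL terminator
        buffer = ctypes.create_string_buffer(buffer_size)
        fill(buffer, buffer_size)
    return buffer.value.decode("utf-8")

# Toy stand-in for a C metadata getter, just to exercise the helper.
def fake_fill(buf, size):
    data = b"llama.context_length=4096"
    ctypes.memmove(buf, data[: size - 1], min(len(data), size - 1))
    return len(data)

print(read_c_string(fake_fill, initial_size=8))
```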

llama_cpp/llama.py

@@ -87,7 +87,7 @@ class Llama:
        # Backend Params
        numa: bool = False,
        # Chat Format Params
-        chat_format: str = "llama-2",
+        chat_format: Optional[str] = None,
        chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
        # Misc
        verbose: bool = True,
@@ -343,6 +343,41 @@ class Llama:
        if self.verbose:
            print(f"Model metadata: {self.metadata}", file=sys.stderr)
        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
            if chat_format is not None:
                self.chat_format = chat_format
                if self.verbose:
                    print(f"Guessed chat format: {chat_format}", file=sys.stderr)
            else:
                template = self.metadata["tokenizer.chat_template"]
                try:
                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
                except:
                    eos_token_id = self.token_eos()
                try:
                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
                except:
                    bos_token_id = self.token_bos()
                eos_token = self.detokenize([eos_token_id]).decode("utf-8")
                bos_token = self.detokenize([bos_token_id]).decode("utf-8")
                if self.verbose:
                    print(f"Using chat template: {template}", file=sys.stderr)
                    print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                    print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
                    template=template,
                    eos_token=eos_token,
                    bos_token=bos_token
                ).to_chat_handler()
        if self.chat_format is None and self.chat_handler is None:
            self.chat_format = "llama-2"
    @property
    def ctx(self) -> llama_cpp.llama_context_p:
        assert self._ctx.ctx is not None
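The net effect of the auto-detection block above, as a hedged usage sketch (the model path is a placeholder): with a gguf that ships a `tokenizer.chat_template`, `chat_format` can now simply be omitted.
```python
from llama_cpp import Llama

# chat_format is left as None; a recognised tokenizer.chat_template is mapped
# to a registered format (e.g. "chatml"), an unrecognised one is used directly
# as a Jinja2 chat handler, and only as a last resort does it fall back to "llama-2".
llm = Llama(model_path="path/to/model.gguf", verbose=True)

result = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
)
print(result["choices"][0]["message"]["content"])
```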

llama_cpp/llama_chat_format.py

@@ -14,6 +14,20 @@ import llama_cpp.llama_grammar as llama_grammar
from ._utils import suppress_stdout_stderr, Singleton
### Common Chat Templates and Special Tokens ###
# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
CHATML_BOS_TOKEN = "<s>"
CHATML_EOS_TOKEN = "<|im_end|>"
# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
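To make the mapping concrete, here is a hedged sketch of what the ChatML template above expands to when rendered directly with `jinja2` (already a dependency of this module); the messages are made up for illustration:
```python
import jinja2

# Verbatim copy of CHATML_CHAT_TEMPLATE from above.
CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

prompt = jinja2.Template(CHATML_CHAT_TEMPLATE).render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    add_generation_prompt=True,
)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
```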
### Chat Completion Handler ###
class LlamaChatCompletionHandler(Protocol):
    """Base Protocol for a llama chat completion handler.
@@ -118,7 +132,6 @@ def register_chat_completion_handler(name: str):
### Chat Formatter ###
@dataclasses.dataclass
class ChatFormatterResponse:
    """Dataclass that stores completion parameters for a given chat format and
@@ -172,16 +185,17 @@ class Jinja2ChatFormatter(ChatFormatter):
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
-        if self.add_generation_prompt:
-            messages = [
-                *messages,
-                llama_types.ChatCompletionRequestAssistantMessage(
-                    role="assistant", content=""
-                ),
-            ]
+        def raise_exception(message: str):
+            raise ValueError(message)
        prompt = self._environment.render(
-            messages=messages, eos_token=self.eos_token, bos_token=self.bos_token
+            messages=messages,
+            eos_token=self.eos_token,
+            bos_token=self.bos_token,
+            raise_exception=raise_exception,
+            add_generation_prompt=self.add_generation_prompt
        )
        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
    def to_chat_handler(self) -> LlamaChatCompletionHandler:
@@ -318,7 +332,14 @@ def chat_formatter_to_chat_completion_handler(
        stop = stop + rstop
        if response_format is not None and response_format["type"] == "json_object":
-            grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+            try:
+                # create grammar from json schema
+                if "schema" in response_format:
+                    grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                        json.dumps(response_format["schema"])
+                    )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
        completion_or_chunks = llama.create_completion(
            prompt=prompt,
@@ -433,7 +454,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
    return chat_formatter_to_chat_completion_handler(chat_formatter)
def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
    if "tokenizer.chat_template" not in metadata:
        return None
    if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
        return "chatml"
    if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
        return "mistral-instruct"
    return None
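A hedged sketch of how this lookup behaves on its own; the metadata dict is fabricated for illustration:
```python
import llama_cpp.llama_chat_format as llama_chat_format

# Metadata as it would be read out of a gguf file; the template string must
# match CHATML_CHAT_TEMPLATE byte-for-byte to be recognised.
metadata = {"tokenizer.chat_template": llama_chat_format.CHATML_CHAT_TEMPLATE}
print(llama_chat_format.guess_chat_format_from_gguf_metadata(metadata))  # chatml

# Anything unknown falls through to None, and the caller then builds a
# Jinja2ChatFormatter from the raw template instead.
print(llama_chat_format.guess_chat_format_from_gguf_metadata({}))  # None
```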
### Utility functions for formatting chat prompts ###
# TODO: Replace these with jinja2 templates
def _get_system_message(
@@ -870,6 +904,24 @@ def format_chatml(
    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
@register_chat_format("mistral-instruct")
def format_mistral_instruct(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
bos = "<s>"
eos = "</s>"
stop = eos
prompt = bos
for message in messages:
if message["role"] == "user" and message["content"] is not None and isinstance(message["content"], str):
prompt += "[INST] " + message["content"]
elif message["role"] == "assistant" and message["content"] is not None and isinstance(message["content"], str):
prompt += " [/INST]" + message["content"] + eos
prompt += " [/INST]"
return ChatFormatterResponse(prompt=prompt, stop=stop)
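For orientation, a short sketch of the prompt this formatter produces for a made-up exchange (tracing the loop above):
```python
import llama_cpp.llama_chat_format as llama_chat_format

response = llama_chat_format.format_mistral_instruct(
    messages=[
        {"role": "user", "content": "Instruction"},
        {"role": "assistant", "content": "Model answer"},
        {"role": "user", "content": "Follow-up instruction"},
    ],
)
print(response.prompt)
# <s>[INST] Instruction [/INST]Model answer</s>[INST] Follow-up instruction [/INST]
print(response.stop)  # </s>
```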
@register_chat_format("chatglm3") @register_chat_format("chatglm3")
def format_chatglm3( def format_chatglm3(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
@ -904,7 +956,6 @@ def format_openchat(
_prompt = _format_chatml(system_message, _messages, _sep) _prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep) return ChatFormatterResponse(prompt=_prompt, stop=_sep)
# Chat format for Saiga models, see more details and available models:
# https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
@register_chat_format("saiga")
@@ -926,6 +977,7 @@ def format_saiga(
    _prompt += "<s>bot"
    return ChatFormatterResponse(prompt=_prompt.strip())
# Tricky chat formats that require custom chat handlers
@register_chat_completion_handler("functionary") @register_chat_completion_handler("functionary")
def functionary_chat_handler( def functionary_chat_handler(
@ -1434,10 +1486,14 @@ class Llava15ChatHandler:
prompt = llama.input_ids[: llama.n_tokens].tolist() prompt = llama.input_ids[: llama.n_tokens].tolist()
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
-            with suppress_stdout_stderr(disable=self.verbose):
-                grammar = llama_grammar.LlamaGrammar.from_string(
-                    llama_grammar.JSON_GBNF
-                )
+            try:
+                # create grammar from json schema
+                if "schema" in response_format:
+                    grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                        json.dumps(response_format["schema"])
+                    )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
        return _convert_completion_to_chat(
            llama.create_completion(

llama_cpp/llama_cpp.py

@@ -93,14 +93,12 @@ c_size_t_p = POINTER(c_size_t)
# from ggml-backend.h
# typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
-ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(
-    c_bool, c_void_p, c_bool, c_void_p
-)
+ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(c_bool, c_void_p, c_bool, c_void_p)
# llama.h bindings
_lib.llama_max_devices.argtypes = []
-_lib.llama_max_devices.restype = ctypes.c_int32
+_lib.llama_max_devices.restype = ctypes.c_size_t
LLAMA_MAX_DEVICES = _lib.llama_max_devices()
@@ -189,6 +187,7 @@ LLAMA_TOKEN_TYPE_BYTE = 6
# LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors
# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
# LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
# };
@@ -213,6 +212,7 @@ LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19
LLAMA_FTYPE_MOSTLY_IQ2_XS = 20
LLAMA_FTYPE_MOSTLY_Q2_K_S = 21
LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23
LLAMA_FTYPE_GUESSED = 1024
# enum llama_rope_scaling_type {
@@ -390,7 +390,7 @@ class llama_model_kv_override(Structure):
# // LLAMA_SPLIT_LAYER: ignored
# int32_t main_gpu;
-# // proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
+# // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
# const float * tensor_split;
# // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
@@ -417,7 +417,7 @@ class llama_model_params(Structure):
        n_gpu_layers (int): number of layers to store in VRAM
        split_mode (int): how to split the model across multiple GPUs
        main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored
-        tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
+        tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
        progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted.
        progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback
        kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data
@@ -760,16 +760,43 @@ _lib.llama_time_us.argtypes = []
_lib.llama_time_us.restype = ctypes.c_int64
-# LLAMA_API int32_t llama_max_devices(void);
+# LLAMA_API size_t llama_max_devices(void);
def llama_max_devices() -> int:
    return _lib.llama_max_devices()
_lib.llama_max_devices.argtypes = []
-_lib.llama_max_devices.restype = ctypes.c_int32
+_lib.llama_max_devices.restype = ctypes.c_size_t
-# LLAMA_API bool llama_mmap_supported (void);
+# LLAMA_API bool llama_supports_mmap (void);
def llama_supports_mmap() -> bool:
    return _lib.llama_supports_mmap()
_lib.llama_supports_mmap.argtypes = []
_lib.llama_supports_mmap.restype = c_bool
# LLAMA_API bool llama_supports_mlock (void);
def llama_supports_mlock() -> bool:
    return _lib.llama_supports_mlock()
_lib.llama_supports_mlock.argtypes = []
_lib.llama_supports_mlock.restype = c_bool
# LLAMA_API bool llama_supports_gpu_offload(void);
def llama_supports_gpu_offload() -> bool:
    return _lib.llama_supports_gpu_offload()
_lib.llama_supports_gpu_offload.argtypes = []
_lib.llama_supports_gpu_offload.restype = c_bool
# LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
def llama_mmap_supported() -> bool:
    return _lib.llama_mmap_supported()
@@ -778,7 +805,7 @@ _lib.llama_mmap_supported.argtypes = []
_lib.llama_mmap_supported.restype = c_bool
-# LLAMA_API bool llama_mlock_supported(void);
+# LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
def llama_mlock_supported() -> bool:
    return _lib.llama_mlock_supported()
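Taken together, the new `llama_supports_*` wrappers give a quick way to probe what the loaded shared library was built with; a brief sketch (the reported capabilities depend entirely on how llama.cpp was compiled):
```python
import llama_cpp

# Each of these is a thin ctypes wrapper around the corresponding llama.h call.
print("mmap support:", llama_cpp.llama_supports_mmap())
print("mlock support:", llama_cpp.llama_supports_mlock())
print("GPU offload:", llama_cpp.llama_supports_gpu_offload())
print("max devices:", llama_cpp.llama_max_devices())
```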
@@ -2174,6 +2201,34 @@ _lib.llama_sample_typical.argtypes = [
_lib.llama_sample_typical.restype = None
# /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
# LLAMA_API void llama_sample_entropy(
#     struct llama_context * ctx,
#     llama_token_data_array * candidates_p,
#     float min_temp,
#     float max_temp,
#     float exponent_val);
def llama_sample_entropy(
    ctx: llama_context_p,
    candidates,  # type: _Pointer[llama_token_data_array]
    min_temp: Union[c_float, float],
    max_temp: Union[c_float, float],
    exponent_val: Union[c_float, float],
):
    """Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772."""
    return _lib.llama_sample_entropy(ctx, candidates, min_temp, max_temp, exponent_val)
_lib.llama_sample_entropy.argtypes = [
    llama_context_p,
    llama_token_data_array_p,
    c_float,
    c_float,
    c_float,
]
_lib.llama_sample_entropy.restype = None
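The binding just forwards to llama.cpp, but the idea behind dynamic temperature is easy to sketch in plain Python; this is an illustrative approximation of the paper's approach, not the library's exact code:
```python
import math

def dynamic_temperature(probs, min_temp=0.5, max_temp=1.5, exponent_val=1.0):
    """Scale temperature by the normalized entropy of the candidate
    distribution: peaked (confident) distributions sample cooler,
    flat (uncertain) ones sample hotter."""
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    max_entropy = math.log(len(probs))  # entropy of a uniform distribution
    normalized = entropy / max_entropy if max_entropy > 0 else 0.0
    return min_temp + (max_temp - min_temp) * normalized ** exponent_val

print(dynamic_temperature([0.97, 0.01, 0.01, 0.01]))  # ~0.62: cooler, distribution is peaked
print(dynamic_temperature([0.25, 0.25, 0.25, 0.25]))  # 1.5: hottest, distribution is uniform
```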
# LLAMA_API void llama_sample_temp(
#     struct llama_context * ctx,
#     llama_token_data_array * candidates,

llama_cpp/llama_types.py

@@ -154,6 +154,7 @@ class ChatCompletionFunctionCallOption(TypedDict):
class ChatCompletionRequestResponseFormat(TypedDict):
    type: Literal["text", "json_object"]
    schema: NotRequired[JsonType]  # https://docs.endpoints.anyscale.com/guides/json_mode/
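A quick sketch of a value that satisfies this TypedDict, mirroring the README's JSON Schema example:
```python
from llama_cpp.llama_types import ChatCompletionRequestResponseFormat

response_format: ChatCompletionRequestResponseFormat = {
    "type": "json_object",
    # Optional: when present, a grammar is derived from this JSON Schema.
    "schema": {
        "type": "object",
        "properties": {"team_name": {"type": "string"}},
        "required": ["team_name"],
    },
}
```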
class ChatCompletionRequestMessageContentPartText(TypedDict):

llama_cpp/server/settings.py

@@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
        description="Enable NUMA support.",
    )
    # Chat Format Params
-    chat_format: str = Field(
-        default="llama-2",
+    chat_format: Optional[str] = Field(
+        default=None,
        description="Chat format to use.",
    )
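With the default now `None`, the server inherits the gguf-based auto-detection described above. A hedged sketch of the two ways to configure it (field names as in `ModelSettings`; paths are placeholders):
```python
from llama_cpp.server.settings import ModelSettings

# Leave chat_format unset to let the gguf metadata decide...
auto = ModelSettings(model="path/to/model.gguf")
print(auto.chat_format)  # None

# ...or pin a registered format explicitly.
pinned = ModelSettings(model="path/to/model.gguf", chat_format="chatml")
print(pinned.chat_format)  # chatml
```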
    clip_model_path: Optional[str] = Field(

tests/test_llama_chat_format.py

@@ -1,10 +1,33 @@
import json
import jinja2
from llama_cpp import (
    ChatCompletionRequestUserMessage,
)
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_chat_format as llama_chat_format
from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
def test_mistral_instruct():
    chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
    chat_formatter = jinja2.Template(chat_template)
    messages = [
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Instruction"),
        llama_types.ChatCompletionRequestAssistantMessage(role="assistant", content="Model answer"),
        llama_types.ChatCompletionRequestUserMessage(role="user", content="Follow-up instruction"),
    ]
    response = llama_chat_format.format_mistral_instruct(
        messages=messages,
    )
    reference = chat_formatter.render(
        messages=messages,
        bos_token="<s>",
        eos_token="</s>",
    )
    assert response.prompt == reference
mistral_7b_tokenizer_config = """{
  "add_bos_token": true,

vendor/llama.cpp

@@ -1 +1 @@
-Subproject commit faa3526a1eba458120987ed8269e5616385a76f4
+Subproject commit 5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3