diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a94ef5..435af43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.37] + +- feat: Update llama.cpp to ggerganov/llama.cpp@fea4fd4ba7f6b754ac795387b275e1a014a77bde +- feat: Automatically set chat format from gguf by @abetlen in #1110 + +## [0.2.36] + +- feat: Update llama.cpp to ggerganov/llama.cpp@2aed77eb06a329f0d82bb1c467f4244904d4073f +- feat: Add mistral instruct chat format as "mistral-instruct" by @Rafaelblsilva in #799 + +## [0.2.35] + +- feat: Update llama.cpp to ggerganov/llama.cpp@d2f650cb5b04ee2726663e79b47da5efe196ce00 + +## [0.2.34] + +- feat: Update llama.cpp to ggerganov/llama.cpp@6db2b41a76ee78d5efdd5c3cddd5d7ad3f646855 +- feat: Add json schema mode by @abetlen in #1122 + ## [0.2.33] - feat: Update llama.cpp to ggerganov/llama.cpp@faa3526a1eba458120987ed8269e5616385a76f4 diff --git a/Makefile b/Makefile index 5ed3fa2..ff1484c 100644 --- a/Makefile +++ b/Makefile @@ -27,6 +27,15 @@ build.blis: build.metal: CMAKE_ARGS="-DLLAMA_METAL=on" python3 -m pip install --verbose -e . +build.vulkan: + CMAKE_ARGS="-DLLAMA_VULKAN=on" python3 -m pip install --verbose -e . + +build.kompute: + CMAKE_ARGS="-DLLAMA_KOMPUTE=on" python3 -m pip install --verbose -e . + +build.sycl: + CMAKE_ARGS="-DLLAMA_SYCL=on" python3 -m pip install --verbose -e . + build.sdist: python3 -m build --sdist diff --git a/README.md b/README.md index 7813c96..0a77bbd 100644 --- a/README.md +++ b/README.md @@ -12,20 +12,17 @@ This package provides: - Low-level access to C API via `ctypes` interface. - High-level Python API for text completion - - OpenAI-like API - - [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp) - - [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html) + - OpenAI-like API + - [LangChain compatibility](https://python.langchain.com/docs/integrations/llms/llamacpp) + - [LlamaIndex compatibility](https://docs.llamaindex.ai/en/stable/examples/llm/llama_2_llama_cpp.html) - OpenAI compatible web server - - [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion) - - [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling) - - [Vision API support](https://llama-cpp-python.readthedocs.io/en/latest/server/#multimodal-models) - - [Multiple Models](https://llama-cpp-python.readthedocs.io/en/latest/server/#configuration-and-multi-model-support) + - [Local Copilot replacement](https://llama-cpp-python.readthedocs.io/en/latest/server/#code-completion) + - [Function Calling support](https://llama-cpp-python.readthedocs.io/en/latest/server/#function-calling) + - [Vision API support](https://llama-cpp-python.readthedocs.io/en/latest/server/#multimodal-models) + - [Multiple Models](https://llama-cpp-python.readthedocs.io/en/latest/server/#configuration-and-multi-model-support) Documentation is available at [https://llama-cpp-python.readthedocs.io/en/latest](https://llama-cpp-python.readthedocs.io/en/latest). - - - ## Installation `llama-cpp-python` can be installed directly from PyPI as a source distribution by running: @@ -38,7 +35,6 @@ This will build `llama.cpp` from source using cmake and your system's c compiler If you run into issues during installation add the `--verbose` flag to the `pip install` command to see the full cmake build log. - ### Installation with Specific Hardware Acceleration (BLAS, CUDA, Metal, etc) The default pip install behaviour is to build `llama.cpp` for CPU only on Linux and Windows and use Metal on MacOS. @@ -71,7 +67,7 @@ CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp- #### cuBLAS -To install with cuBLAS, set the `LLAMA_CUBLAS=1` environment variable before installing: +To install with cuBLAS, set the `LLAMA_CUBLAS=on` environment variable before installing: ```bash CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python @@ -87,7 +83,7 @@ CMAKE_ARGS="-DLLAMA_METAL=on" pip install llama-cpp-python #### CLBlast -To install with CLBlast, set the `LLAMA_CLBLAST=1` environment variable before installing: +To install with CLBlast, set the `LLAMA_CLBLAST=on` environment variable before installing: ```bash CMAKE_ARGS="-DLLAMA_CLBLAST=on" pip install llama-cpp-python @@ -101,13 +97,37 @@ To install with hipBLAS / ROCm support for AMD cards, set the `LLAMA_HIPBLAS=on` CMAKE_ARGS="-DLLAMA_HIPBLAS=on" pip install llama-cpp-python ``` +#### Vulkan + +To install with Vulkan support, set the `LLAMA_VULKAN=on` environment variable before installing: + +```bash +CMAKE_ARGS="-DLLAMA_VULKAN=on" pip install llama-cpp-python +``` + +#### Kompute + +To install with Kompute support, set the `LLAMA_KOMPUTE=on` environment variable before installing: + +```bash +CMAKE_ARGS="-DLLAMA_KOMPUTE=on" pip install llama-cpp-python +``` + +#### SYCL + +To install with SYCL support, set the `LLAMA_SYCL=on` environment variable before installing: + +```bash +CMAKE_ARGS="-DLLAMA_SYCL=on" pip install llama-cpp-python +``` + ### Windows Notes If you run into issues where it complains it can't find `'nmake'` `'?'` or CMAKE_C_COMPILER, you can extract w64devkit as [mentioned in llama.cpp repo](https://github.com/ggerganov/llama.cpp#openblas) and add those manually to CMAKE_ARGS before running `pip` install: ```ps $env:CMAKE_GENERATOR = "MinGW Makefiles" -$env:CMAKE_ARGS = "-DLLAMA_OPENBLAS=on -DCMAKE_C_COMPILER=C:/w64devkit/bin/gcc.exe -DCMAKE_CXX_COMPILER=C:/w64devkit/bin/g++.exe" +$env:CMAKE_ARGS = "-DLLAMA_OPENBLAS=on -DCMAKE_C_COMPILER=C:/w64devkit/bin/gcc.exe -DCMAKE_CXX_COMPILER=C:/w64devkit/bin/g++.exe" ``` See the above instructions and set `CMAKE_ARGS` to the BLAS backend you want to use. @@ -157,7 +177,7 @@ Below is a short example demonstrating how to use the high-level API to for basi >>> from llama_cpp import Llama >>> llm = Llama( model_path="./models/7B/llama-model.gguf", - # n_gpu_layers=-1, # Uncomment to use GPU acceleration + # n_gpu_layers=-1, # Uncomment to use GPU acceleration # seed=1337, # Uncomment to set a specific seed # n_ctx=2048, # Uncomment to increase the context window ) @@ -216,6 +236,59 @@ Note that `chat_format` option must be set for the particular model you are usin Chat completion is available through the [`create_chat_completion`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion) method of the [`Llama`](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) class. +### JSON and JSON Schema Mode + +If you want to constrain chat responses to only valid JSON or a specific JSON Schema you can use the `response_format` argument to the `create_chat_completion` method. + +#### JSON Mode + +The following example will constrain the response to be valid JSON. + +```python +>>> from llama_cpp import Llama +>>> llm = Llama(model_path="path/to/model.gguf", chat_format="chatml") +>>> llm.create_chat_completion( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that outputs in JSON.", + }, + {"role": "user", "content": "Who won the world series in 2020"}, + ], + response_format={ + "type": "json_object", + }, + temperature=0.7, +) +``` + +#### JSON Schema Mode + +To constrain the response to a specific JSON Schema, you can use the `schema` property of the `response_format` argument. + +```python +>>> from llama_cpp import Llama +>>> llm = Llama(model_path="path/to/model.gguf", chat_format="chatml") +>>> llm.create_chat_completion( + messages=[ + { + "role": "system", + "content": "You are a helpful assistant that outputs in JSON.", + }, + {"role": "user", "content": "Who won the world series in 2020"}, + ], + response_format={ + "type": "json_object", + "schema": { + "type": "object", + "properties": {"team_name": {"type": "string"}}, + "required": ["team_name"], + }, + }, + temperature=0.7, +) +``` + ### Function Calling The high-level API also provides a simple interface for function calling. @@ -223,7 +296,6 @@ The high-level API also provides a simple interface for function calling. Note that the only model that supports full function calling at this time is "functionary". The gguf-converted files for this model can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF) - ```python >>> from llama_cpp import Llama >>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary") @@ -232,7 +304,7 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h { "role": "system", "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary" - + }, { "role": "user", @@ -271,7 +343,6 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h ### Multi-modal Models - `llama-cpp-python` supports the llava1.5 family of multi-modal models which allow the language model to read information from both text and images. @@ -317,7 +388,6 @@ For instance, if you want to work with larger contexts, you can expand the conte llm = Llama(model_path="./models/7B/llama-model.gguf", n_ctx=2048) ``` - ## OpenAI Compatible Web Server `llama-cpp-python` offers a web server which aims to act as a drop-in replacement for the OpenAI API. @@ -365,7 +435,8 @@ A Docker image is available on [GHCR](https://ghcr.io/abetlen/llama-cpp-python). ```bash docker run --rm -it -p 8000:8000 -v /path/to/models:/models -e MODEL=/models/llama-model.gguf ghcr.io/abetlen/llama-cpp-python:latest ``` -[Docker on termux (requires root)](https://gist.github.com/FreddieOliveira/efe850df7ff3951cb62d74bd770dce27) is currently the only known way to run this on phones, see [termux support issue](https://github.com/abetlen/llama-cpp-python/issues/389) + +[Docker on termux (requires root)](https://gist.github.com/FreddieOliveira/efe850df7ff3951cb62d74bd770dce27) is currently the only known way to run this on phones, see [termux support issue](https://github.com/abetlen/llama-cpp-python/issues/389) ## Low-level API @@ -393,7 +464,6 @@ Below is a short example demonstrating how to use the low-level API to tokenize Check out the [examples folder](examples/low_level_api) for more examples of using the low-level API. - ## Documentation Documentation is available via [https://llama-cpp-python.readthedocs.io/](https://llama-cpp-python.readthedocs.io/). diff --git a/examples/high_level_api/fastapi_server.py b/examples/high_level_api/fastapi_server.py index 4b3189d..9421db5 100644 --- a/examples/high_level_api/fastapi_server.py +++ b/examples/high_level_api/fastapi_server.py @@ -9,7 +9,7 @@ export MODEL=../models/7B/... Then run: ``` -uvicorn llama_cpp.server.app:app --reload +uvicorn --factory llama_cpp.server.app:create_app --reload ``` or diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 55f695e..4ce899c 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.33" \ No newline at end of file +__version__ = "0.2.37" \ No newline at end of file diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index ec47c42..651cd4c 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -216,13 +216,13 @@ class _LlamaModel: for i in range(llama_cpp.llama_model_meta_count(self.model)): nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size) if nbytes > buffer_size: - buffer_size = nbytes + buffer_size = nbytes + 1 buffer = ctypes.create_string_buffer(buffer_size) nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size) key = buffer.value.decode("utf-8") nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size) if nbytes > buffer_size: - buffer_size = nbytes + buffer_size = nbytes + 1 buffer = ctypes.create_string_buffer(buffer_size) nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size) value = buffer.value.decode("utf-8") diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 74739cb..b5618c1 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -87,7 +87,7 @@ class Llama: # Backend Params numa: bool = False, # Chat Format Params - chat_format: str = "llama-2", + chat_format: Optional[str] = None, chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None, # Misc verbose: bool = True, @@ -343,6 +343,41 @@ class Llama: if self.verbose: print(f"Model metadata: {self.metadata}", file=sys.stderr) + if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata: + chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata) + + if chat_format is not None: + self.chat_format = chat_format + if self.verbose: + print(f"Guessed chat format: {chat_format}", file=sys.stderr) + else: + template = self.metadata["tokenizer.chat_template"] + try: + eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"]) + except: + eos_token_id = self.token_eos() + try: + bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"]) + except: + bos_token_id = self.token_bos() + + eos_token = self.detokenize([eos_token_id]).decode("utf-8") + bos_token = self.detokenize([bos_token_id]).decode("utf-8") + + if self.verbose: + print(f"Using chat template: {template}", file=sys.stderr) + print(f"Using chat eos_token: {eos_token}", file=sys.stderr) + print(f"Using chat bos_token: {bos_token}", file=sys.stderr) + + self.chat_handler = llama_chat_format.Jinja2ChatFormatter( + template=template, + eos_token=eos_token, + bos_token=bos_token + ).to_chat_handler() + + if self.chat_format is None and self.chat_handler is None: + self.chat_format = "llama-2" + @property def ctx(self) -> llama_cpp.llama_context_p: assert self._ctx.ctx is not None diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 6c274aa..08f991b 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -14,6 +14,20 @@ import llama_cpp.llama_grammar as llama_grammar from ._utils import suppress_stdout_stderr, Singleton +### Common Chat Templates and Special Tokens ### + +# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json +CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" +CHATML_BOS_TOKEN = "" +CHATML_EOS_TOKEN = "<|im_end|>" + +# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json +MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" +MISTRAL_INSTRUCT_BOS_TOKEN = "" +MISTRAL_INSTRUCT_EOS_TOKEN = "" + + +### Chat Completion Handler ### class LlamaChatCompletionHandler(Protocol): """Base Protocol for a llama chat completion handler. @@ -118,7 +132,6 @@ def register_chat_completion_handler(name: str): ### Chat Formatter ### - @dataclasses.dataclass class ChatFormatterResponse: """Dataclass that stores completion parameters for a given chat format and @@ -172,16 +185,17 @@ class Jinja2ChatFormatter(ChatFormatter): messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: - if self.add_generation_prompt: - messages = [ - *messages, - llama_types.ChatCompletionRequestAssistantMessage( - role="assistant", content="" - ), - ] + def raise_exception(message: str): + raise ValueError(message) + prompt = self._environment.render( - messages=messages, eos_token=self.eos_token, bos_token=self.bos_token + messages=messages, + eos_token=self.eos_token, + bos_token=self.bos_token, + raise_exception=raise_exception, + add_generation_prompt=self.add_generation_prompt ) + return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) def to_chat_handler(self) -> LlamaChatCompletionHandler: @@ -318,7 +332,14 @@ def chat_formatter_to_chat_completion_handler( stop = stop + rstop if response_format is not None and response_format["type"] == "json_object": - grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) + try: + # create grammar from json schema + if "schema" in response_format: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(response_format["schema"]) + ) + except Exception as e: + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) completion_or_chunks = llama.create_completion( prompt=prompt, @@ -433,7 +454,20 @@ def hf_tokenizer_config_to_chat_completion_handler( return chat_formatter_to_chat_completion_handler(chat_formatter) +def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]: + if "tokenizer.chat_template" not in metadata: + return None + + if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE: + return "chatml" + + if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE: + return "mistral-instruct" + + return None + ### Utility functions for formatting chat prompts ### +# TODO: Replace these with jinja2 templates def _get_system_message( @@ -870,6 +904,24 @@ def format_chatml( return ChatFormatterResponse(prompt=_prompt, stop=_sep) +@register_chat_format("mistral-instruct") +def format_mistral_instruct( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + bos = "" + eos = "" + stop = eos + prompt = bos + for message in messages: + if message["role"] == "user" and message["content"] is not None and isinstance(message["content"], str): + prompt += "[INST] " + message["content"] + elif message["role"] == "assistant" and message["content"] is not None and isinstance(message["content"], str): + prompt += " [/INST]" + message["content"] + eos + prompt += " [/INST]" + return ChatFormatterResponse(prompt=prompt, stop=stop) + + @register_chat_format("chatglm3") def format_chatglm3( messages: List[llama_types.ChatCompletionRequestMessage], @@ -904,7 +956,6 @@ def format_openchat( _prompt = _format_chatml(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt, stop=_sep) - # Chat format for Saiga models, see more details and available models: # https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd @register_chat_format("saiga") @@ -926,6 +977,7 @@ def format_saiga( _prompt += "bot" return ChatFormatterResponse(prompt=_prompt.strip()) +# Tricky chat formats that require custom chat handlers @register_chat_completion_handler("functionary") def functionary_chat_handler( @@ -1434,10 +1486,14 @@ class Llava15ChatHandler: prompt = llama.input_ids[: llama.n_tokens].tolist() if response_format is not None and response_format["type"] == "json_object": - with suppress_stdout_stderr(disable=self.verbose): - grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF - ) + try: + # create grammar from json schema + if "schema" in response_format: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(response_format["schema"]) + ) + except Exception as e: + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) return _convert_completion_to_chat( llama.create_completion( diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index d31a5da..431a99f 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -93,14 +93,12 @@ c_size_t_p = POINTER(c_size_t) # from ggml-backend.h # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); -ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE( - c_bool, c_void_p, c_bool, c_void_p -) +ggml_backend_sched_eval_callback = ctypes.CFUNCTYPE(c_bool, c_void_p, c_bool, c_void_p) # llama.h bindings _lib.llama_max_devices.argtypes = [] -_lib.llama_max_devices.restype = ctypes.c_int32 +_lib.llama_max_devices.restype = ctypes.c_size_t LLAMA_MAX_DEVICES = _lib.llama_max_devices() @@ -189,6 +187,7 @@ LLAMA_TOKEN_TYPE_BYTE = 6 # LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors # LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors # LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors +# LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors # LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file # }; @@ -213,6 +212,7 @@ LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19 LLAMA_FTYPE_MOSTLY_IQ2_XS = 20 LLAMA_FTYPE_MOSTLY_Q2_K_S = 21 LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22 +LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23 LLAMA_FTYPE_GUESSED = 1024 # enum llama_rope_scaling_type { @@ -390,7 +390,7 @@ class llama_model_kv_override(Structure): # // LLAMA_SPLIT_LAYER: ignored # int32_t main_gpu; -# // proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES +# // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() # const float * tensor_split; # // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. @@ -417,7 +417,7 @@ class llama_model_params(Structure): n_gpu_layers (int): number of layers to store in VRAM split_mode (int): how to split the model across multiple GPUs main_gpu (int): the GPU that is used for the entire model. main_gpu interpretation depends on split_mode: LLAMA_SPLIT_NONE: the GPU that is used for the entire model LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results LLAMA_SPLIT_LAYER: ignored - tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES + tensor_split (ctypes.Array[ctypes.c_float]): proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() progress_callback (llama_progress_callback): called with a progress value between 0.0 and 1.0. Pass NULL to disable. If the provided progress_callback returns true, model loading continues. If it returns false, model loading is immediately aborted. progress_callback_user_data (ctypes.c_void_p): context pointer passed to the progress callback kv_overrides (ctypes.Array[llama_model_kv_override]): override key-value pairs of the model meta data @@ -760,16 +760,43 @@ _lib.llama_time_us.argtypes = [] _lib.llama_time_us.restype = ctypes.c_int64 -# LLAMA_API int32_t llama_max_devices(void); +# LLAMA_API size_t llama_max_devices(void); def llama_max_devices() -> int: return _lib.llama_max_devices() _lib.llama_max_devices.argtypes = [] -_lib.llama_max_devices.restype = ctypes.c_int32 +_lib.llama_max_devices.restype = ctypes.c_size_t -# LLAMA_API bool llama_mmap_supported (void); +# LLAMA_API bool llama_supports_mmap (void); +def llama_supports_mmap() -> bool: + return _lib.llama_supports_mmap() + + +_lib.llama_supports_mmap.argtypes = [] +_lib.llama_supports_mmap.restype = c_bool + + +# LLAMA_API bool llama_supports_mlock (void); +def llama_supports_mlock() -> bool: + return _lib.llama_supports_mlock() + + +_lib.llama_supports_mlock.argtypes = [] +_lib.llama_supports_mlock.restype = c_bool + + +# LLAMA_API bool llama_supports_gpu_offload(void); +def llama_supports_gpu_offload() -> bool: + return _lib.llama_supports_gpu_offload() + + +_lib.llama_supports_gpu_offload.argtypes = [] +_lib.llama_supports_gpu_offload.restype = c_bool + + +# LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead"); def llama_mmap_supported() -> bool: return _lib.llama_mmap_supported() @@ -778,7 +805,7 @@ _lib.llama_mmap_supported.argtypes = [] _lib.llama_mmap_supported.restype = c_bool -# LLAMA_API bool llama_mlock_supported(void); +# LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead"); def llama_mlock_supported() -> bool: return _lib.llama_mlock_supported() @@ -2174,6 +2201,34 @@ _lib.llama_sample_typical.argtypes = [ _lib.llama_sample_typical.restype = None +# /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. +# LLAMA_API void llama_sample_entropy( +# struct llama_context * ctx, +# llama_token_data_array * candidates_p, +# float min_temp, +# float max_temp, +# float exponent_val); +def llama_sample_entropy( + ctx: llama_context_p, + candidates, # type: _Pointer[llama_token_data_array] + min_temp: Union[c_float, float], + max_temp: Union[c_float, float], + exponent_val: Union[c_float, float], +): + """Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.""" + return _lib.llama_sample_entropy(ctx, candidates, min_temp, max_temp, exponent_val) + + +_lib.llama_sample_entropy.argtypes = [ + llama_context_p, + llama_token_data_array_p, + c_float, + c_float, + c_float, +] +_lib.llama_sample_entropy.restype = None + + # LLAMA_API void llama_sample_temp( # struct llama_context * ctx, # llama_token_data_array * candidates, diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index 5b51e98..c3deba8 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -154,6 +154,7 @@ class ChatCompletionFunctionCallOption(TypedDict): class ChatCompletionRequestResponseFormat(TypedDict): type: Literal["text", "json_object"] + schema: NotRequired[JsonType] # https://docs.endpoints.anyscale.com/guides/json_mode/ class ChatCompletionRequestMessageContentPartText(TypedDict): diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 9f0dc8a..9fe1a7b 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -113,8 +113,8 @@ class ModelSettings(BaseSettings): description="Enable NUMA support.", ) # Chat Format Params - chat_format: str = Field( - default="llama-2", + chat_format: Optional[str] = Field( + default=None, description="Chat format to use.", ) clip_model_path: Optional[str] = Field( diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py index 1ef18d9..c10aee4 100644 --- a/tests/test_llama_chat_format.py +++ b/tests/test_llama_chat_format.py @@ -1,10 +1,33 @@ import json +import jinja2 + from llama_cpp import ( ChatCompletionRequestUserMessage, ) +import llama_cpp.llama_types as llama_types +import llama_cpp.llama_chat_format as llama_chat_format + from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter +def test_mistral_instruct(): + chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + chat_formatter = jinja2.Template(chat_template) + messages = [ + llama_types.ChatCompletionRequestUserMessage(role="user", content="Instruction"), + llama_types.ChatCompletionRequestAssistantMessage(role="assistant", content="Model answer"), + llama_types.ChatCompletionRequestUserMessage(role="user", content="Follow-up instruction"), + ] + response = llama_chat_format.format_mistral_instruct( + messages=messages, + ) + reference = chat_formatter.render( + messages=messages, + bos_token="", + eos_token="", + ) + assert response.prompt == reference + mistral_7b_tokenizer_config = """{ "add_bos_token": true, diff --git a/vendor/llama.cpp b/vendor/llama.cpp index faa3526..5cb04db 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit faa3526a1eba458120987ed8269e5616385a76f4 +Subproject commit 5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3