From da003d87681f02475eedb6937443e5f07db889b0 Mon Sep 17 00:00:00 2001
From: Andrei
Date: Mon, 29 Jan 2024 14:22:23 -0500
Subject: [PATCH] Automatically set chat format from gguf (#1110)

* Use jinja formatter to load chat format from gguf

* Fix off-by-one error in metadata loader

* Implement chat format auto-detection
---
 llama_cpp/_internals.py        |  4 ++--
 llama_cpp/llama.py             | 37 +++++++++++++++++++++++++++++++++-
 llama_cpp/llama_chat_format.py | 30 ++++++++++++++++++++++++++++--
 llama_cpp/server/settings.py   |  4 ++--
 4 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index ec47c42..651cd4c 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -216,13 +216,13 @@ class _LlamaModel:
         for i in range(llama_cpp.llama_model_meta_count(self.model)):
             nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             key = buffer.value.decode("utf-8")
             nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             value = buffer.value.decode("utf-8")
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 74739cb..b5618c1 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -87,7 +87,7 @@ class Llama:
         # Backend Params
         numa: bool = False,
         # Chat Format Params
-        chat_format: str = "llama-2",
+        chat_format: Optional[str] = None,
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
         # Misc
         verbose: bool = True,
@@ -343,6 +343,41 @@ class Llama:
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
+        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+
+            if chat_format is not None:
+                self.chat_format = chat_format
+                if self.verbose:
+                    print(f"Guessed chat format: {chat_format}", file=sys.stderr)
+            else:
+                template = self.metadata["tokenizer.chat_template"]
+                try:
+                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
+                except:
+                    eos_token_id = self.token_eos()
+                try:
+                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
+                except:
+                    bos_token_id = self.token_bos()
+
+                eos_token = self.detokenize([eos_token_id]).decode("utf-8")
+                bos_token = self.detokenize([bos_token_id]).decode("utf-8")
+
+                if self.verbose:
+                    print(f"Using chat template: {template}", file=sys.stderr)
+                    print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
+                    print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
+
+                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
+                    template=template,
+                    eos_token=eos_token,
+                    bos_token=bos_token
+                ).to_chat_handler()
+
+        if self.chat_format is None and self.chat_handler is None:
+            self.chat_format = "llama-2"
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 5466de3..4bc4a6c 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -14,6 +14,20 @@ import llama_cpp.llama_grammar as llama_grammar
 
 from ._utils import suppress_stdout_stderr, Singleton
 
+### Common Chat Templates and Special Tokens ###
+
+# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+CHATML_BOS_TOKEN = "<s>"
+CHATML_EOS_TOKEN = "<|im_end|>"
+
+# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
+MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
+
+
+### Chat Completion Handler ###
 
 class LlamaChatCompletionHandler(Protocol):
     """Base Protocol for a llama chat completion handler.
@@ -118,7 +132,6 @@ def register_chat_completion_handler(name: str):
 
 ### Chat Formatter ###
 
-
 @dataclasses.dataclass
 class ChatFormatterResponse:
     """Dataclass that stores completion parameters for a given chat format and
@@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)
 
 
+def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
+    if "tokenizer.chat_template" not in metadata:
+        return None
+
+    if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
+        return "chatml"
+
+    if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
+        return "mistral-instruct"
+
+    return None
+
 ### Utility functions for formatting chat prompts ###
+# TODO: Replace these with jinja2 templates
 
 
 def _get_system_message(
@@ -929,7 +955,6 @@ def format_openchat(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
-
 # Chat format for Saiga models, see more details and available models:
 # https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
 @register_chat_format("saiga")
@@ -951,6 +976,7 @@ def format_saiga(
     _prompt += "<s>bot"
     return ChatFormatterResponse(prompt=_prompt.strip())
 
 
+# Tricky chat formats that require custom chat handlers
 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index 9f0dc8a..9fe1a7b 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
         description="Enable NUMA support.",
     )
     # Chat Format Params
-    chat_format: str = Field(
-        default="llama-2",
+    chat_format: Optional[str] = Field(
+        default=None,
         description="Chat format to use.",
     )
     clip_model_path: Optional[str] = Field(
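
Usage sketch: a minimal example of the auto-detection introduced by this patch, assuming a local GGUF file whose metadata carries a tokenizer.chat_template key and that neither chat_format nor chat_handler is passed explicitly. The model path and prompt are hypothetical placeholders.

from llama_cpp import Llama
from llama_cpp.llama_chat_format import guess_chat_format_from_gguf_metadata

# chat_format is left unset (None): with this patch the constructor guesses it from
# the GGUF metadata ("chatml" or "mistral-instruct"), or falls back to building a
# chat handler from tokenizer.chat_template via Jinja2ChatFormatter.
llm = Llama(model_path="./models/openhermes-2.5-mistral-7b.Q4_K_M.gguf")  # placeholder path

# The detection helper can also be called directly on the loaded metadata;
# it returns None for templates it does not recognize.
print(guess_chat_format_from_gguf_metadata(llm.metadata))

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}]
)
print(response["choices"][0]["message"]["content"])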