Automatically set chat format from gguf (#1110)

* Use jinja formatter to load chat format from gguf

* Fix off-by-one error in metadata loader

* Implement chat format auto-detection
This commit is contained in:
Andrei 2024-01-29 14:22:23 -05:00 committed by GitHub
parent 059f6b3ac8
commit da003d8768
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 68 additions and 7 deletions

View file

@ -216,13 +216,13 @@ class _LlamaModel:
for i in range(llama_cpp.llama_model_meta_count(self.model)):
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
if nbytes > buffer_size:
buffer_size = nbytes
buffer_size = nbytes + 1
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
key = buffer.value.decode("utf-8")
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
if nbytes > buffer_size:
buffer_size = nbytes
buffer_size = nbytes + 1
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
value = buffer.value.decode("utf-8")

View file

@ -87,7 +87,7 @@ class Llama:
# Backend Params
numa: bool = False,
# Chat Format Params
chat_format: str = "llama-2",
chat_format: Optional[str] = None,
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Misc
verbose: bool = True,
@ -343,6 +343,41 @@ class Llama:
if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)
if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
if chat_format is not None:
self.chat_format = chat_format
if self.verbose:
print(f"Guessed chat format: {chat_format}", file=sys.stderr)
else:
template = self.metadata["tokenizer.chat_template"]
try:
eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
except:
eos_token_id = self.token_eos()
try:
bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
except:
bos_token_id = self.token_bos()
eos_token = self.detokenize([eos_token_id]).decode("utf-8")
bos_token = self.detokenize([bos_token_id]).decode("utf-8")
if self.verbose:
print(f"Using chat template: {template}", file=sys.stderr)
print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template,
eos_token=eos_token,
bos_token=bos_token
).to_chat_handler()
if self.chat_format is None and self.chat_handler is None:
self.chat_format = "llama-2"
@property
def ctx(self) -> llama_cpp.llama_context_p:
assert self._ctx.ctx is not None

View file

@ -14,6 +14,20 @@ import llama_cpp.llama_grammar as llama_grammar
from ._utils import suppress_stdout_stderr, Singleton
### Common Chat Templates and Special Tokens ###
# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
CHATML_BOS_TOKEN = "<s>"
CHATML_EOS_TOKEN = "<|im_end|>"
# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
### Chat Completion Handler ###
class LlamaChatCompletionHandler(Protocol):
"""Base Protocol for a llama chat completion handler.
@ -118,7 +132,6 @@ def register_chat_completion_handler(name: str):
### Chat Formatter ###
@dataclasses.dataclass
class ChatFormatterResponse:
"""Dataclass that stores completion parameters for a given chat format and
@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
return chat_formatter_to_chat_completion_handler(chat_formatter)
def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
if "tokenizer.chat_template" not in metadata:
return None
if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
return "chatml"
if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
return "mistral-instruct"
return None
### Utility functions for formatting chat prompts ###
# TODO: Replace these with jinja2 templates
def _get_system_message(
@ -929,7 +955,6 @@ def format_openchat(
_prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
# Chat format for Saiga models, see more details and available models:
# https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
@register_chat_format("saiga")
@ -951,6 +976,7 @@ def format_saiga(
_prompt += "<s>bot"
return ChatFormatterResponse(prompt=_prompt.strip())
# Tricky chat formats that require custom chat handlers
@register_chat_completion_handler("functionary")
def functionary_chat_handler(

View file

@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
description="Enable NUMA support.",
)
# Chat Format Params
chat_format: str = Field(
default="llama-2",
chat_format: Optional[str] = Field(
default=None,
description="Chat format to use.",
)
clip_model_path: Optional[str] = Field(