Automatically set chat format from gguf (#1110)
* Use jinja formatter to load chat format from gguf * Fix off-by-one error in metadata loader * Implement chat format auto-detection
This commit is contained in:
parent
059f6b3ac8
commit
da003d8768
4 changed files with 68 additions and 7 deletions
|
@ -216,13 +216,13 @@ class _LlamaModel:
|
|||
for i in range(llama_cpp.llama_model_meta_count(self.model)):
|
||||
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
|
||||
if nbytes > buffer_size:
|
||||
buffer_size = nbytes
|
||||
buffer_size = nbytes + 1
|
||||
buffer = ctypes.create_string_buffer(buffer_size)
|
||||
nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
|
||||
key = buffer.value.decode("utf-8")
|
||||
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
|
||||
if nbytes > buffer_size:
|
||||
buffer_size = nbytes
|
||||
buffer_size = nbytes + 1
|
||||
buffer = ctypes.create_string_buffer(buffer_size)
|
||||
nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
|
||||
value = buffer.value.decode("utf-8")
|
||||
|
|
|
@ -87,7 +87,7 @@ class Llama:
|
|||
# Backend Params
|
||||
numa: bool = False,
|
||||
# Chat Format Params
|
||||
chat_format: str = "llama-2",
|
||||
chat_format: Optional[str] = None,
|
||||
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
|
||||
# Misc
|
||||
verbose: bool = True,
|
||||
|
@ -343,6 +343,41 @@ class Llama:
|
|||
if self.verbose:
|
||||
print(f"Model metadata: {self.metadata}", file=sys.stderr)
|
||||
|
||||
if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
|
||||
chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
|
||||
|
||||
if chat_format is not None:
|
||||
self.chat_format = chat_format
|
||||
if self.verbose:
|
||||
print(f"Guessed chat format: {chat_format}", file=sys.stderr)
|
||||
else:
|
||||
template = self.metadata["tokenizer.chat_template"]
|
||||
try:
|
||||
eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
|
||||
except:
|
||||
eos_token_id = self.token_eos()
|
||||
try:
|
||||
bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
|
||||
except:
|
||||
bos_token_id = self.token_bos()
|
||||
|
||||
eos_token = self.detokenize([eos_token_id]).decode("utf-8")
|
||||
bos_token = self.detokenize([bos_token_id]).decode("utf-8")
|
||||
|
||||
if self.verbose:
|
||||
print(f"Using chat template: {template}", file=sys.stderr)
|
||||
print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
|
||||
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
||||
|
||||
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
||||
template=template,
|
||||
eos_token=eos_token,
|
||||
bos_token=bos_token
|
||||
).to_chat_handler()
|
||||
|
||||
if self.chat_format is None and self.chat_handler is None:
|
||||
self.chat_format = "llama-2"
|
||||
|
||||
@property
|
||||
def ctx(self) -> llama_cpp.llama_context_p:
|
||||
assert self._ctx.ctx is not None
|
||||
|
|
|
@ -14,6 +14,20 @@ import llama_cpp.llama_grammar as llama_grammar
|
|||
|
||||
from ._utils import suppress_stdout_stderr, Singleton
|
||||
|
||||
### Common Chat Templates and Special Tokens ###
|
||||
|
||||
# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
|
||||
CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
CHATML_BOS_TOKEN = "<s>"
|
||||
CHATML_EOS_TOKEN = "<|im_end|>"
|
||||
|
||||
# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
||||
MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
|
||||
MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
|
||||
MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
|
||||
|
||||
|
||||
### Chat Completion Handler ###
|
||||
|
||||
class LlamaChatCompletionHandler(Protocol):
|
||||
"""Base Protocol for a llama chat completion handler.
|
||||
|
@ -118,7 +132,6 @@ def register_chat_completion_handler(name: str):
|
|||
|
||||
### Chat Formatter ###
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ChatFormatterResponse:
|
||||
"""Dataclass that stores completion parameters for a given chat format and
|
||||
|
@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
|
|||
return chat_formatter_to_chat_completion_handler(chat_formatter)
|
||||
|
||||
|
||||
def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
|
||||
if "tokenizer.chat_template" not in metadata:
|
||||
return None
|
||||
|
||||
if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
|
||||
return "chatml"
|
||||
|
||||
if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
|
||||
return "mistral-instruct"
|
||||
|
||||
return None
|
||||
|
||||
### Utility functions for formatting chat prompts ###
|
||||
# TODO: Replace these with jinja2 templates
|
||||
|
||||
|
||||
def _get_system_message(
|
||||
|
@ -929,7 +955,6 @@ def format_openchat(
|
|||
_prompt = _format_chatml(system_message, _messages, _sep)
|
||||
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
|
||||
|
||||
|
||||
# Chat format for Saiga models, see more details and available models:
|
||||
# https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
|
||||
@register_chat_format("saiga")
|
||||
|
@ -951,6 +976,7 @@ def format_saiga(
|
|||
_prompt += "<s>bot"
|
||||
return ChatFormatterResponse(prompt=_prompt.strip())
|
||||
|
||||
# Tricky chat formats that require custom chat handlers
|
||||
|
||||
@register_chat_completion_handler("functionary")
|
||||
def functionary_chat_handler(
|
||||
|
|
|
@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
|
|||
description="Enable NUMA support.",
|
||||
)
|
||||
# Chat Format Params
|
||||
chat_format: str = Field(
|
||||
default="llama-2",
|
||||
chat_format: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Chat format to use.",
|
||||
)
|
||||
clip_model_path: Optional[str] = Field(
|
||||
|
|
Loading…
Reference in a new issue