Automatically set chat format from gguf (#1110)
* Use jinja formatter to load chat format from gguf
* Fix off-by-one error in metadata loader
* Implement chat format auto-detection
parent 059f6b3ac8
commit da003d8768

4 changed files with 68 additions and 7 deletions
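With this change, `chat_format` is resolved from the GGUF metadata when it is not given explicitly. A minimal usage sketch of the new behaviour (the model path is a placeholder; `llama_cpp` and `jinja2` are assumed to be installed):

```python
from llama_cpp import Llama

# Placeholder path; any GGUF whose metadata carries tokenizer.chat_template works.
llm = Llama(model_path="./models/openhermes-2.5-mistral-7b.Q4_K_M.gguf", verbose=True)

# With chat_format left as None, the format is guessed from the GGUF metadata
# (e.g. "chatml") or, failing that, a Jinja2 chat handler is built from the
# embedded template; only if neither applies does it fall back to "llama-2".
out = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=32,
)
print(out["choices"][0]["message"]["content"])
```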
@@ -216,13 +216,13 @@ class _LlamaModel:
         for i in range(llama_cpp.llama_model_meta_count(self.model)):
             nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_key_by_index(self.model, i, buffer, buffer_size)
             key = buffer.value.decode("utf-8")
             nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             if nbytes > buffer_size:
-                buffer_size = nbytes
+                buffer_size = nbytes + 1
                 buffer = ctypes.create_string_buffer(buffer_size)
                 nbytes = llama_cpp.llama_model_meta_val_str_by_index(self.model, i, buffer, buffer_size)
             value = buffer.value.decode("utf-8")
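The `+ 1` is the off-by-one fix named in the commit message: these llama.cpp getters follow snprintf-style semantics (the return value is the string length excluding the terminating NUL), so retrying with a buffer of exactly `nbytes` bytes truncates the last character. A standalone sketch of the grow-and-retry pattern; `fake_meta_key` below is a stand-in for the C call, not part of the library:

```python
import ctypes

def fake_meta_key(i: int, buf, buf_size: int) -> int:
    # Stand-in for llama_cpp.llama_model_meta_key_by_index: copies at most
    # buf_size - 1 bytes plus a NUL and returns the full key length without
    # the NUL, mirroring the snprintf-style convention.
    key = b"tokenizer.chat_template"
    buf.value = key[: max(buf_size - 1, 0)]
    return len(key)

buffer_size = 8  # deliberately too small
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = fake_meta_key(0, buffer, buffer_size)
if nbytes > buffer_size:
    buffer_size = nbytes + 1  # the fix: leave room for the NUL terminator
    buffer = ctypes.create_string_buffer(buffer_size)
    nbytes = fake_meta_key(0, buffer, buffer_size)
print(buffer.value.decode("utf-8"))  # tokenizer.chat_template (not truncated)
```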
@@ -87,7 +87,7 @@ class Llama:
         # Backend Params
         numa: bool = False,
         # Chat Format Params
-        chat_format: str = "llama-2",
+        chat_format: Optional[str] = None,
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
         # Misc
         verbose: bool = True,
@@ -343,6 +343,41 @@ class Llama:
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)

+        if self.chat_format is None and self.chat_handler is None and "tokenizer.chat_template" in self.metadata:
+            chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(self.metadata)
+
+            if chat_format is not None:
+                self.chat_format = chat_format
+                if self.verbose:
+                    print(f"Guessed chat format: {chat_format}", file=sys.stderr)
+            else:
+                template = self.metadata["tokenizer.chat_template"]
+                try:
+                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
+                except:
+                    eos_token_id = self.token_eos()
+                try:
+                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
+                except:
+                    bos_token_id = self.token_bos()
+
+                eos_token = self.detokenize([eos_token_id]).decode("utf-8")
+                bos_token = self.detokenize([bos_token_id]).decode("utf-8")
+
+                if self.verbose:
+                    print(f"Using chat template: {template}", file=sys.stderr)
+                    print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
+                    print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
+
+                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
+                    template=template,
+                    eos_token=eos_token,
+                    bos_token=bos_token
+                ).to_chat_handler()
+
+        if self.chat_format is None and self.chat_handler is None:
+            self.chat_format = "llama-2"
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         assert self._ctx.ctx is not None
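The resolution order introduced here is: explicit arguments win, then a known format guessed from `tokenizer.chat_template`, then a Jinja2 handler built from the raw template, then the `llama-2` default. A condensed paraphrase of that logic as a plain function (not the library code; the `"jinja2-template"` sentinel is purely illustrative):

```python
from typing import Dict, Optional

from llama_cpp.llama_chat_format import guess_chat_format_from_gguf_metadata

def resolve_chat_format(chat_format: Optional[str], metadata: Dict[str, str]) -> str:
    """Condensed paraphrase of the new resolution order (no chat_handler case).

    Returns a registered format name, or the sentinel "jinja2-template" meaning
    a Jinja2ChatFormatter is built from tokenizer.chat_template.
    """
    if chat_format is not None:
        return chat_format                       # explicit argument always wins
    if "tokenizer.chat_template" in metadata:
        guessed = guess_chat_format_from_gguf_metadata(metadata)
        if guessed is not None:
            return guessed                       # e.g. "chatml" or "mistral-instruct"
        return "jinja2-template"                 # fall back to the raw template
    return "llama-2"                             # legacy default
```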
@@ -14,6 +14,20 @@ import llama_cpp.llama_grammar as llama_grammar

 from ._utils import suppress_stdout_stderr, Singleton

+### Common Chat Templates and Special Tokens ###
+
+# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
+CHATML_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+CHATML_BOS_TOKEN = "<s>"
+CHATML_EOS_TOKEN = "<|im_end|>"
+
+# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
+MISTRAL_INSTRUCT_CHAT_TEMPLATE = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+MISTRAL_INSTRUCT_BOS_TOKEN = "<s>"
+MISTRAL_INSTRUCT_EOS_TOKEN = "</s>"
+
+
+### Chat Completion Handler ###

 class LlamaChatCompletionHandler(Protocol):
     """Base Protocol for a llama chat completion handler.
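These constants are the exact template strings from the referenced tokenizer_config.json files; detection is a literal string comparison against them. For reference, rendering the ChatML template with plain jinja2, a simplified sketch of what the Jinja2ChatFormatter fallback does rather than the library code itself:

```python
from jinja2 import Environment

from llama_cpp.llama_chat_format import CHATML_CHAT_TEMPLATE

template = Environment().from_string(CHATML_CHAT_TEMPLATE)
prompt = template.render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
    ],
    add_generation_prompt=True,
)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
```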
@@ -118,7 +132,6 @@ def register_chat_completion_handler(name: str):

 ### Chat Formatter ###

-
 @dataclasses.dataclass
 class ChatFormatterResponse:
     """Dataclass that stores completion parameters for a given chat format and
@@ -440,7 +453,20 @@ def hf_tokenizer_config_to_chat_completion_handler(
     return chat_formatter_to_chat_completion_handler(chat_formatter)


+def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
+    if "tokenizer.chat_template" not in metadata:
+        return None
+
+    if metadata["tokenizer.chat_template"] == CHATML_CHAT_TEMPLATE:
+        return "chatml"
+
+    if metadata["tokenizer.chat_template"] == MISTRAL_INSTRUCT_CHAT_TEMPLATE:
+        return "mistral-instruct"
+
+    return None
+
+
 ### Utility functions for formatting chat prompts ###
 # TODO: Replace these with jinja2 templates


 def _get_system_message(
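Detection is deliberately conservative: only an exact match against one of the two known templates yields a format name; anything else returns None and is handled by the Jinja2 fallback. A small usage sketch (the metadata dicts are hand-made for illustration):

```python
from llama_cpp.llama_chat_format import (
    CHATML_CHAT_TEMPLATE,
    guess_chat_format_from_gguf_metadata,
)

assert guess_chat_format_from_gguf_metadata(
    {"tokenizer.chat_template": CHATML_CHAT_TEMPLATE}
) == "chatml"

# Unknown or missing templates return None, which triggers the Jinja2 fallback
# (or, with no template at all, the "llama-2" default) in Llama.__init__.
assert guess_chat_format_from_gguf_metadata({"tokenizer.chat_template": "{{ messages }}"}) is None
assert guess_chat_format_from_gguf_metadata({}) is None
```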
@@ -929,7 +955,6 @@ def format_openchat(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)

-
 # Chat format for Saiga models, see more details and available models:
 # https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
 @register_chat_format("saiga")
@@ -951,6 +976,7 @@ def format_saiga(
     _prompt += "<s>bot"
     return ChatFormatterResponse(prompt=_prompt.strip())


+# Tricky chat formats that require custom chat handlers
 @register_chat_completion_handler("functionary")
 def functionary_chat_handler(
@@ -113,8 +113,8 @@ class ModelSettings(BaseSettings):
         description="Enable NUMA support.",
     )
     # Chat Format Params
-    chat_format: str = Field(
-        default="llama-2",
+    chat_format: Optional[str] = Field(
+        default=None,
         description="Chat format to use.",
     )
     clip_model_path: Optional[str] = Field(
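The same default change applies to the server settings, so `--chat_format` can simply be omitted and the GGUF metadata decides at load time. A sketch, assuming the settings class is importable from `llama_cpp.server.settings` (the file path is not shown in this diff) and using a placeholder model path:

```python
from llama_cpp.server.settings import ModelSettings

# Placeholder model path; chat_format is intentionally not set.
settings = ModelSettings(model="./models/openhermes-2.5-mistral-7b.Q4_K_M.gguf")
print(settings.chat_format)  # None -> resolved from the model's GGUF metadata at load time
```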