feat: Add ability to load chat format from huggingface autotokenizer or tokenizer_config.json files.

This commit is contained in:
Andrei Betlen 2024-01-18 21:21:37 -05:00
parent 48c3b77e6f
commit b8fc1c7d83
6 changed files with 357 additions and 318 deletions

View file

@ -1,7 +1,8 @@
import os import os
import sys import sys
import sys, traceback import sys
from typing import Any, Dict
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
outnull_file = open(os.devnull, "w") outnull_file = open(os.devnull, "w")
@ -55,3 +56,25 @@ class suppress_stdout_stderr(object):
self.os.close(self.old_stdout_fileno) self.os.close(self.old_stdout_fileno)
self.os.close(self.old_stderr_fileno) self.os.close(self.old_stderr_fileno)
class MetaSingleton(type):
"""
Metaclass for implementing the Singleton pattern.
"""
_instances: Dict[type, Any] = {}
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Singleton(object, metaclass=MetaSingleton):
"""
Base class for implementing the Singleton pattern.
"""
def __init__(self):
super(Singleton, self).__init__()

View file

@ -6,18 +6,28 @@ import ctypes
import dataclasses import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
import jinja2
import llama_cpp.llama as llama import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar import llama_cpp.llama_grammar as llama_grammar
from ._utils import suppress_stdout_stderr from ._utils import suppress_stdout_stderr, Singleton
class LlamaChatCompletionHandler(Protocol): class LlamaChatCompletionHandler(Protocol):
"""Base Protocol for a llama chat completion handler.
Very generic protocol that can be used to implement any chat format.
The only hard requirement is that it must return a ChatCompletion when
stream=False and an iterator of ChatCompletionChunks when stream=True."""
def __call__( def __call__(
self, self,
*, *,
# llama.cpp instance
llama: llama.Llama, llama: llama.Llama,
# openai api parameters
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None, functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol):
temperature: float = 0.2, temperature: float = 0.2,
top_p: float = 0.95, top_p: float = 0.95,
top_k: int = 40, top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None, seed: Optional[int] = None,
@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol):
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
model: Optional[str] = None,
logit_bias: Optional[Dict[str, float]] = None,
# llama.cpp parameters
min_p: float = 0.05,
typical_p: float = 1.0,
tfs_z: float = 1.0, tfs_z: float = 1.0,
mirostat_mode: int = 0, mirostat_mode: int = 0,
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
**kwargs, # type: ignore **kwargs, # type: ignore
) -> Union[ ) -> Union[
llama_types.CreateChatCompletionResponse, llama_types.CreateChatCompletionResponse,
@ -54,21 +65,83 @@ class LlamaChatCompletionHandler(Protocol):
... ...
CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {} class LlamaChatCompletionHandlerNotFoundException(Exception):
pass
class LlamaChatCompletionHandlerRegistry(Singleton):
_chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
def register_chat_completion_handler(
self,
name: str,
chat_handler: LlamaChatCompletionHandler,
overwrite: bool = False,
):
if not overwrite and name in self._chat_handlers:
raise ValueError(
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
)
self._chat_handlers[name] = chat_handler
def unregister_chat_handler(self, name: str):
if name in self._chat_handlers:
del self._chat_handlers[name]
else:
raise ValueError(f"No formatter registered under the name '{name}'.")
def get_chat_completion_handler_by_name(
self, name: str
) -> LlamaChatCompletionHandler:
try:
chat_handler = self._chat_handlers[name]
return chat_handler
except KeyError:
raise LlamaChatCompletionHandlerNotFoundException(
f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
)
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler: def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
return CHAT_HANDLERS[name] return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
name
)
def register_chat_completion_handler(name: str): def register_chat_completion_handler(name: str):
def decorator(f: LlamaChatCompletionHandler): def decorator(f: LlamaChatCompletionHandler):
CHAT_HANDLERS[name] = f LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
return f return f
return decorator return decorator
### Chat Formatter ###
@dataclasses.dataclass
class ChatFormatterResponse:
prompt: str
stop: Optional[Union[str, List[str]]] = None
class ChatFormatter(Protocol):
"""Base Protocol for a chat formatter. A chat formatter is a function that
takes a list of messages and returns a formatted prompt. It can also return
a stop token or list of stop tokens to use for the completion."""
def __call__(
self,
*,
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
...
### Utility functions for formatting chat prompts ###
def _get_system_message( def _get_system_message(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
) -> str: ) -> str:
@ -80,14 +153,18 @@ def _get_system_message(
def _map_roles( def _map_roles(
messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str] messages: List[llama_types.ChatCompletionRequestMessage],
role_map: Dict[str, str],
) -> List[Tuple[str, Optional[str]]]: ) -> List[Tuple[str, Optional[str]]]:
"""Map the message roles.""" """Map the message roles."""
output: List[Tuple[str, Optional[str]]] = [] output: List[Tuple[str, Optional[str]]] = []
for message in messages: for message in messages:
role = message["role"] role = message["role"]
if role in role_map: if role in role_map:
output.append((role_map[role], message["content"])) content: str | None = (
message["content"] if isinstance(message["content"], str) else None
)
output.append((role_map[role], content))
return output return output
@ -99,7 +176,8 @@ def _format_llama2(
ret = system_message + sep ret = system_message + sep
for i, (role, message) in enumerate(messages): for i, (role, message) in enumerate(messages):
if system_message and i == 0: if system_message and i == 0:
ret += message + seps[i % 2] m = message or ""
ret += m + seps[i % 2]
elif message: elif message:
ret += role + message + " " + seps[i % 2] ret += role + message + " " + seps[i % 2]
else: else:
@ -172,6 +250,7 @@ def _format_chatml(
ret += role + "\n" ret += role + "\n"
return ret return ret
def _format_chatglm3( def _format_chatglm3(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str: ) -> str:
@ -187,30 +266,10 @@ def _format_chatglm3(
return ret return ret
@dataclasses.dataclass
class ChatFormatterResponse:
prompt: str
stop: Optional[Union[str, List[str]]] = None
class ChatFormatter(Protocol):
def __call__(
self,
*,
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
...
class BasicChatHandler:
def __init__(self, chat_format: str):
self.chat_format = chat_format
def _convert_text_completion_to_chat( def _convert_text_completion_to_chat(
completion: llama_types.Completion, completion: llama_types.Completion,
) -> llama_types.ChatCompletion: ) -> llama_types.ChatCompletion:
assert "usage" in completion
return { return {
"id": "chat" + completion["id"], "id": "chat" + completion["id"],
"object": "chat.completion", "object": "chat.completion",
@ -286,103 +345,95 @@ def _convert_completion_to_chat(
return _convert_text_completion_to_chat(completion) return _convert_text_completion_to_chat(completion)
_CHAT_FORMATS: Dict[str, ChatFormatter] = {} def chat_formatter_to_chat_completion_handler(
chat_formatter: ChatFormatter,
) -> LlamaChatCompletionHandler:
def chat_completion_handler(
*,
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
max_tokens: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
result = chat_formatter(
messages=messages,
functions=functions,
function_call=function_call,
)
prompt = result.prompt
if result.stop is not None:
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object":
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
completion_or_chunks = llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
seed=seed,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
logit_bias=logit_bias,
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
return chat_completion_handler
def register_chat_format(name: str): def register_chat_format(name: str):
def decorator(f: ChatFormatter): def decorator(f: ChatFormatter):
def basic_create_chat_completion( chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
*, LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
llama: llama.Llama, name, chat_completion_handler
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[
llama_types.ChatCompletionRequestFunctionCall
] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
max_tokens: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
result = f(
messages=messages,
functions=functions,
function_call=function_call,
)
prompt = result.prompt
if result.stop is not None:
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object":
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF
)
completion_or_chunks = llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
seed=seed,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
logit_bias=logit_bias,
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
register_chat_completion_handler(name)(basic_create_chat_completion)
return f
return decorator
def get_chat_format(name: str):
try:
return _CHAT_FORMATS[name]
except KeyError:
raise ValueError(
f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
) )
return f
return decorator
def hf_autotokenizer_to_chat_formatter( def hf_autotokenizer_to_chat_formatter(
@ -391,22 +442,78 @@ def hf_autotokenizer_to_chat_formatter(
# https://huggingface.co/docs/transformers/main/chat_templating # https://huggingface.co/docs/transformers/main/chat_templating
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
from transformers import AutoTokenizer from transformers import AutoTokenizer # type: ignore
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore
def format_autotokenizer( def format_autotokenizer(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
tokenizer.use_default_system_prompt = False tokenizer.use_default_system_prompt = False # type: ignore
_prompt = tokenizer.apply_chat_template(messages, tokenize=False) prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
assert isinstance(prompt, str)
# Return formatted prompt and eos token by default # Return formatted prompt and eos token by default
return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token) return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
return format_autotokenizer return format_autotokenizer
def hf_autotokenizer_to_chat_completion_handler(
pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> LlamaChatCompletionHandler:
chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
return chat_formatter_to_chat_completion_handler(chat_formatter)
def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
assert isinstance(tokenizer_config, dict)
assert "chat_template" in tokenizer_config
assert isinstance(tokenizer_config["chat_template"], str)
chat_template = tokenizer_config["chat_template"]
assert "bos_token" in tokenizer_config
assert isinstance(tokenizer_config["bos_token"], str)
bos_token = tokenizer_config["bos_token"]
assert "eos_token" in tokenizer_config
assert isinstance(tokenizer_config["eos_token"], str)
eos_token = tokenizer_config["eos_token"]
env = jinja2.Environment(
loader=jinja2.BaseLoader(),
trim_blocks=True,
lstrip_blocks=True,
).from_string(chat_template)
def format_autotokenizer(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
# TODO: veryify this is correct
# Add a blank assistant message to the end of the messages to prompt the model to generate a response
prompt = env.render(
messages=[
*messages,
llama_types.ChatCompletionRequestAssistantMessage(
role="assistant", content=""
),
],
bos_token=bos_token,
eos_token=eos_token,
)
return ChatFormatterResponse(prompt=prompt, stop=eos_token)
return format_autotokenizer
def hf_tokenizer_config_to_chat_completion_handler(
tokenizer_config: Dict[str, Any],
) -> LlamaChatCompletionHandler:
chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
return chat_formatter_to_chat_completion_handler(chat_formatter)
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message # system prompt is "embedded" in the first message
@register_chat_format("llama-2") @register_chat_format("llama-2")
@ -437,21 +544,23 @@ def format_alpaca(
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
return ChatFormatterResponse(prompt=_prompt) return ChatFormatterResponse(prompt=_prompt)
@register_chat_format("qwen") @register_chat_format("qwen")
def format_qwen( def format_qwen(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
_roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant") _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
system_message="You are a helpful assistant." system_message = "You are a helpful assistant."
system_template="<|im_start|>system\n{system_message}" system_template = "<|im_start|>system\n{system_message}"
system_message=system_template.format(system_message=system_message) system_message = system_template.format(system_message=system_message)
_messages = _map_roles(messages, _roles) _messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None)) _messages.append((_roles["assistant"], None))
_sep = "<|im_end|>" _sep = "<|im_end|>"
_prompt = _format_chatml(system_message, _messages, _sep) _prompt = _format_chatml(system_message, _messages, _sep)
_sep2 = "<|endoftext|>" _sep2 = "<|endoftext|>"
return ChatFormatterResponse(prompt=_prompt,stop=_sep2) return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
@register_chat_format("vicuna") @register_chat_format("vicuna")
def format( def format(
@ -650,6 +759,7 @@ def format_mistrallite(
_prompt = _format_no_colon_single(system_message, _messages, _sep) _prompt = _format_no_colon_single(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt) return ChatFormatterResponse(prompt=_prompt)
@register_chat_format("zephyr") @register_chat_format("zephyr")
def format_zephyr( def format_zephyr(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
@ -699,6 +809,7 @@ def format_chatml(
_prompt = _format_chatml(system_message, _messages, _sep) _prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep) return ChatFormatterResponse(prompt=_prompt, stop=_sep)
@register_chat_format("chatglm3") @register_chat_format("chatglm3")
def format_chatglm3( def format_chatglm3(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
@ -739,7 +850,7 @@ def format_openchat(
@register_chat_format("saiga") @register_chat_format("saiga")
def format_saiga( def format_saiga(
messages: list[llama_types.ChatCompletionRequestMessage], messages: list[llama_types.ChatCompletionRequestMessage],
**kwargs, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
_message_template = "<s>{role}\n{content}</s>" _message_template = "<s>{role}\n{content}</s>"
_roles = dict(user="user", bot="bot", system="system") _roles = dict(user="user", bot="bot", system="system")

View file

@ -1,138 +0,0 @@
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
import jinja2
from jinja2 import Template
# NOTE: We sacrifice readability for usability.
# It will fail to work as expected if we attempt to format it in a readable way.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
class MetaSingleton(type):
"""
Metaclass for implementing the Singleton pattern.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Singleton(object, metaclass=MetaSingleton):
"""
Base class for implementing the Singleton pattern.
"""
def __init__(self):
super(Singleton, self).__init__()
@dataclasses.dataclass
class ChatFormatterResponse:
prompt: str
stop: Optional[Union[str, List[str]]] = None
# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
def __init__(self, template: Optional[object] = None):
...
def __call__(
self,
messages: List[Dict[str, str]],
**kwargs,
) -> ChatFormatterResponse:
...
@property
def template(self) -> str:
...
class AutoChatFormatter(ChatFormatterInterface):
def __init__(
self,
template: Optional[str] = None,
template_class: Optional[Template] = None,
):
if template is not None:
self._template = template
else:
self._template = llama2_template # default template
self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(),
trim_blocks=True,
lstrip_blocks=True,
).from_string(
self._template,
template_class=template_class,
)
def __call__(
self,
messages: List[Dict[str, str]],
**kwargs: Any,
) -> ChatFormatterResponse:
formatted_sequence = self._environment.render(messages=messages, **kwargs)
return ChatFormatterResponse(prompt=formatted_sequence)
@property
def template(self) -> str:
return self._template
class FormatterNotFoundException(Exception):
pass
class ChatFormatterFactory(Singleton):
_chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
def register_formatter(
self,
name: str,
formatter_callable: Callable[[], ChatFormatterInterface],
overwrite=False,
):
if not overwrite and name in self._chat_formatters:
raise ValueError(
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
)
self._chat_formatters[name] = formatter_callable
def unregister_formatter(self, name: str):
if name in self._chat_formatters:
del self._chat_formatters[name]
else:
raise ValueError(f"No formatter registered under the name '{name}'.")
def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
try:
formatter_callable = self._chat_formatters[name]
return formatter_callable()
except KeyError:
raise FormatterNotFoundException(
f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
)
# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
def __init__(self):
super().__init__(llama2_template)
# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)

View file

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json
from typing import Dict, Optional, Union, List from typing import Dict, Optional, Union, List
import llama_cpp import llama_cpp
@ -71,7 +73,25 @@ class LlamaProxy:
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler( chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose clip_model_path=settings.clip_model_path, verbose=settings.verbose
) )
elif settings.chat_format == "hf-autotokenizer":
assert (
settings.hf_pretrained_model_name_or_path is not None
), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
chat_handler = (
llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_formatter(
settings.hf_pretrained_model_name_or_path
)
)
elif settings.chat_format == "hf-tokenizer-config":
assert (
settings.hf_tokenizer_config_path is not None
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
chat_handler = (
llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_formatter(
json.load(open(settings.hf_tokenizer_config_path))
)
)
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
if settings.kv_overrides is not None: if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list) assert isinstance(settings.kv_overrides, list)
@ -141,4 +161,3 @@ class LlamaProxy:
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size) cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
_model.set_cache(cache) _model.set_cache(cache)
return _model return _model

View file

@ -134,6 +134,15 @@ class ModelSettings(BaseSettings):
default=2 << 30, default=2 << 30,
description="The size of the cache in bytes. Only used if cache is True.", description="The size of the cache in bytes. Only used if cache is True.",
) )
# Tokenizer Options
hf_tokenizer_config_path: Optional[str] = Field(
default=None,
description="The path to a HuggingFace tokenizer_config.json file.",
)
hf_pretrained_model_name_or_path: Optional[str] = Field(
default=None,
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
)
# Misc # Misc
verbose: bool = Field( verbose: bool = Field(
default=True, description="Whether to print debug information." default=True, description="Whether to print debug information."

View file

@ -1,50 +1,65 @@
from typing import List import json
import pytest from llama_cpp import (
ChatCompletionRequestUserMessage,
from llama_cpp import ChatCompletionMessage )
from llama_cpp.llama_jinja_format import Llama2Formatter from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
@pytest.fixture mistral_7b_tokenizer_config = """{
def sequence_of_messages() -> List[ChatCompletionMessage]: "add_bos_token": true,
return [ "add_eos_token": false,
ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), "added_tokens_decoder": {
ChatCompletionMessage( "0": {
role="user", content="Hi there! I need some help with Python." "content": "<unk>",
), "lstrip": false,
ChatCompletionMessage( "normalized": false,
role="assistant", content="Of course! What do you need help with in Python?" "rstrip": false,
), "single_word": false,
ChatCompletionMessage( "special": true
role="user", },
content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", "1": {
), "content": "<s>",
ChatCompletionMessage( "lstrip": false,
role="assistant", "normalized": false,
content="I can help with that! Would you like a recursive or iterative solution?", "rstrip": false,
), "single_word": false,
ChatCompletionMessage( "special": true
role="user", content="Let's go with a recursive solution." },
), "2": {
] "content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [],
"bos_token": "<s>",
"clean_up_tokenization_spaces": false,
"eos_token": "</s>",
"legacy": true,
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"sp_model_kwargs": {},
"spaces_between_special_tokens": false,
"tokenizer_class": "LlamaTokenizer",
"unk_token": "<unk>",
"use_default_system_prompt": false,
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
}"""
def test_llama2_formatter(sequence_of_messages): def test_hf_tokenizer_config_str_to_chat_formatter():
expected_prompt = ( tokenizer_config = json.loads(mistral_7b_tokenizer_config)
"<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n" chat_formatter = hf_tokenizer_config_to_chat_formatter(
"[INST] Hi there! I need some help with Python. [/INST]\n" tokenizer_config
"Of course! What do you need help with in Python?\n" )
"[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n" chat_formatter_respoonse = chat_formatter(
"I can help with that! Would you like a recursive or iterative solution?\n" messages=[
"[INST] Let's go with a recursive solution. [/INST]\n" ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
]
) )
llama2_formatter_instance = Llama2Formatter() assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>" "")
formatter_response = llama2_formatter_instance(sequence_of_messages)
assert (
expected_prompt == formatter_response.prompt
), "The formatted prompt does not match the expected output."
# Optionally, include a test for the 'stop' if it's part of the functionality.