feat: Add ability to load chat format from huggingface autotokenizer or tokenizer_config.json files.
This commit is contained in:
parent
48c3b77e6f
commit
b8fc1c7d83
6 changed files with 357 additions and 318 deletions
|
@ -1,7 +1,8 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import sys, traceback
|
import sys
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
|
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
|
||||||
outnull_file = open(os.devnull, "w")
|
outnull_file = open(os.devnull, "w")
|
||||||
|
@ -55,3 +56,25 @@ class suppress_stdout_stderr(object):
|
||||||
|
|
||||||
self.os.close(self.old_stdout_fileno)
|
self.os.close(self.old_stdout_fileno)
|
||||||
self.os.close(self.old_stderr_fileno)
|
self.os.close(self.old_stderr_fileno)
|
||||||
|
|
||||||
|
|
||||||
|
class MetaSingleton(type):
|
||||||
|
"""
|
||||||
|
Metaclass for implementing the Singleton pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instances: Dict[type, Any] = {}
|
||||||
|
|
||||||
|
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
|
class Singleton(object, metaclass=MetaSingleton):
|
||||||
|
"""
|
||||||
|
Base class for implementing the Singleton pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(Singleton, self).__init__()
|
||||||
|
|
|
@ -6,18 +6,28 @@ import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
|
||||||
import llama_cpp.llama as llama
|
import llama_cpp.llama as llama
|
||||||
import llama_cpp.llama_types as llama_types
|
import llama_cpp.llama_types as llama_types
|
||||||
import llama_cpp.llama_grammar as llama_grammar
|
import llama_cpp.llama_grammar as llama_grammar
|
||||||
|
|
||||||
from ._utils import suppress_stdout_stderr
|
from ._utils import suppress_stdout_stderr, Singleton
|
||||||
|
|
||||||
|
|
||||||
class LlamaChatCompletionHandler(Protocol):
|
class LlamaChatCompletionHandler(Protocol):
|
||||||
|
"""Base Protocol for a llama chat completion handler.
|
||||||
|
|
||||||
|
Very generic protocol that can be used to implement any chat format.
|
||||||
|
The only hard requirement is that it must return a ChatCompletion when
|
||||||
|
stream=False and an iterator of ChatCompletionChunks when stream=True."""
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
# llama.cpp instance
|
||||||
llama: llama.Llama,
|
llama: llama.Llama,
|
||||||
|
# openai api parameters
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||||
|
@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
min_p: float = 0.05,
|
|
||||||
typical_p: float = 1.0,
|
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
# llama.cpp parameters
|
||||||
|
min_p: float = 0.05,
|
||||||
|
typical_p: float = 1.0,
|
||||||
tfs_z: float = 1.0,
|
tfs_z: float = 1.0,
|
||||||
mirostat_mode: int = 0,
|
mirostat_mode: int = 0,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
model: Optional[str] = None,
|
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
logit_bias: Optional[Dict[str, float]] = None,
|
|
||||||
**kwargs, # type: ignore
|
**kwargs, # type: ignore
|
||||||
) -> Union[
|
) -> Union[
|
||||||
llama_types.CreateChatCompletionResponse,
|
llama_types.CreateChatCompletionResponse,
|
||||||
|
@ -54,21 +65,83 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
|
class LlamaChatCompletionHandlerNotFoundException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaChatCompletionHandlerRegistry(Singleton):
|
||||||
|
_chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
|
||||||
|
|
||||||
|
def register_chat_completion_handler(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
chat_handler: LlamaChatCompletionHandler,
|
||||||
|
overwrite: bool = False,
|
||||||
|
):
|
||||||
|
if not overwrite and name in self._chat_handlers:
|
||||||
|
raise ValueError(
|
||||||
|
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
|
||||||
|
)
|
||||||
|
self._chat_handlers[name] = chat_handler
|
||||||
|
|
||||||
|
def unregister_chat_handler(self, name: str):
|
||||||
|
if name in self._chat_handlers:
|
||||||
|
del self._chat_handlers[name]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No formatter registered under the name '{name}'.")
|
||||||
|
|
||||||
|
def get_chat_completion_handler_by_name(
|
||||||
|
self, name: str
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
|
try:
|
||||||
|
chat_handler = self._chat_handlers[name]
|
||||||
|
return chat_handler
|
||||||
|
except KeyError:
|
||||||
|
raise LlamaChatCompletionHandlerNotFoundException(
|
||||||
|
f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
|
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
|
||||||
return CHAT_HANDLERS[name]
|
return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
|
||||||
|
name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_chat_completion_handler(name: str):
|
def register_chat_completion_handler(name: str):
|
||||||
def decorator(f: LlamaChatCompletionHandler):
|
def decorator(f: LlamaChatCompletionHandler):
|
||||||
CHAT_HANDLERS[name] = f
|
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
### Chat Formatter ###
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ChatFormatterResponse:
|
||||||
|
prompt: str
|
||||||
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatFormatter(Protocol):
|
||||||
|
"""Base Protocol for a chat formatter. A chat formatter is a function that
|
||||||
|
takes a list of messages and returns a formatted prompt. It can also return
|
||||||
|
a stop token or list of stop tokens to use for the completion."""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatFormatterResponse:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
### Utility functions for formatting chat prompts ###
|
||||||
|
|
||||||
|
|
||||||
def _get_system_message(
|
def _get_system_message(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -80,14 +153,18 @@ def _get_system_message(
|
||||||
|
|
||||||
|
|
||||||
def _map_roles(
|
def _map_roles(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str]
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
role_map: Dict[str, str],
|
||||||
) -> List[Tuple[str, Optional[str]]]:
|
) -> List[Tuple[str, Optional[str]]]:
|
||||||
"""Map the message roles."""
|
"""Map the message roles."""
|
||||||
output: List[Tuple[str, Optional[str]]] = []
|
output: List[Tuple[str, Optional[str]]] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
role = message["role"]
|
role = message["role"]
|
||||||
if role in role_map:
|
if role in role_map:
|
||||||
output.append((role_map[role], message["content"]))
|
content: str | None = (
|
||||||
|
message["content"] if isinstance(message["content"], str) else None
|
||||||
|
)
|
||||||
|
output.append((role_map[role], content))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,7 +176,8 @@ def _format_llama2(
|
||||||
ret = system_message + sep
|
ret = system_message + sep
|
||||||
for i, (role, message) in enumerate(messages):
|
for i, (role, message) in enumerate(messages):
|
||||||
if system_message and i == 0:
|
if system_message and i == 0:
|
||||||
ret += message + seps[i % 2]
|
m = message or ""
|
||||||
|
ret += m + seps[i % 2]
|
||||||
elif message:
|
elif message:
|
||||||
ret += role + message + " " + seps[i % 2]
|
ret += role + message + " " + seps[i % 2]
|
||||||
else:
|
else:
|
||||||
|
@ -172,6 +250,7 @@ def _format_chatml(
|
||||||
ret += role + "\n"
|
ret += role + "\n"
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def _format_chatglm3(
|
def _format_chatglm3(
|
||||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -187,30 +266,10 @@ def _format_chatglm3(
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ChatFormatterResponse:
|
|
||||||
prompt: str
|
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatFormatter(Protocol):
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatFormatterResponse:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class BasicChatHandler:
|
|
||||||
def __init__(self, chat_format: str):
|
|
||||||
self.chat_format = chat_format
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_text_completion_to_chat(
|
def _convert_text_completion_to_chat(
|
||||||
completion: llama_types.Completion,
|
completion: llama_types.Completion,
|
||||||
) -> llama_types.ChatCompletion:
|
) -> llama_types.ChatCompletion:
|
||||||
|
assert "usage" in completion
|
||||||
return {
|
return {
|
||||||
"id": "chat" + completion["id"],
|
"id": "chat" + completion["id"],
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
|
@ -286,19 +345,15 @@ def _convert_completion_to_chat(
|
||||||
return _convert_text_completion_to_chat(completion)
|
return _convert_text_completion_to_chat(completion)
|
||||||
|
|
||||||
|
|
||||||
_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
|
def chat_formatter_to_chat_completion_handler(
|
||||||
|
chat_formatter: ChatFormatter,
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
def register_chat_format(name: str):
|
def chat_completion_handler(
|
||||||
def decorator(f: ChatFormatter):
|
|
||||||
def basic_create_chat_completion(
|
|
||||||
*,
|
*,
|
||||||
llama: llama.Llama,
|
llama: llama.Llama,
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
function_call: Optional[
|
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||||
llama_types.ChatCompletionRequestFunctionCall
|
|
||||||
] = None,
|
|
||||||
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||||
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
|
@ -329,7 +384,7 @@ def register_chat_format(name: str):
|
||||||
llama_types.CreateChatCompletionResponse,
|
llama_types.CreateChatCompletionResponse,
|
||||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||||
]:
|
]:
|
||||||
result = f(
|
result = chat_formatter(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
functions=functions,
|
functions=functions,
|
||||||
function_call=function_call,
|
function_call=function_call,
|
||||||
|
@ -341,9 +396,7 @@ def register_chat_format(name: str):
|
||||||
stop = stop + rstop
|
stop = stop + rstop
|
||||||
|
|
||||||
if response_format is not None and response_format["type"] == "json_object":
|
if response_format is not None and response_format["type"] == "json_object":
|
||||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||||
llama_grammar.JSON_GBNF
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_or_chunks = llama.create_completion(
|
completion_or_chunks = llama.create_completion(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
@ -370,19 +423,17 @@ def register_chat_format(name: str):
|
||||||
)
|
)
|
||||||
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
||||||
|
|
||||||
register_chat_completion_handler(name)(basic_create_chat_completion)
|
return chat_completion_handler
|
||||||
return f
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def get_chat_format(name: str):
|
def register_chat_format(name: str):
|
||||||
try:
|
def decorator(f: ChatFormatter):
|
||||||
return _CHAT_FORMATS[name]
|
chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
|
||||||
except KeyError:
|
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
|
||||||
raise ValueError(
|
name, chat_completion_handler
|
||||||
f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
|
|
||||||
)
|
)
|
||||||
|
return f
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def hf_autotokenizer_to_chat_formatter(
|
def hf_autotokenizer_to_chat_formatter(
|
||||||
|
@ -391,22 +442,78 @@ def hf_autotokenizer_to_chat_formatter(
|
||||||
# https://huggingface.co/docs/transformers/main/chat_templating
|
# https://huggingface.co/docs/transformers/main/chat_templating
|
||||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
|
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
|
||||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer # type: ignore
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore
|
||||||
|
|
||||||
def format_autotokenizer(
|
def format_autotokenizer(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
tokenizer.use_default_system_prompt = False
|
tokenizer.use_default_system_prompt = False # type: ignore
|
||||||
_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
|
prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
|
||||||
|
assert isinstance(prompt, str)
|
||||||
# Return formatted prompt and eos token by default
|
# Return formatted prompt and eos token by default
|
||||||
return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
|
return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
|
||||||
|
|
||||||
return format_autotokenizer
|
return format_autotokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def hf_autotokenizer_to_chat_completion_handler(
|
||||||
|
pretrained_model_name_or_path: Union[str, os.PathLike[str]]
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
|
chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
|
||||||
|
return chat_formatter_to_chat_completion_handler(chat_formatter)
|
||||||
|
|
||||||
|
|
||||||
|
def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
|
||||||
|
assert isinstance(tokenizer_config, dict)
|
||||||
|
|
||||||
|
assert "chat_template" in tokenizer_config
|
||||||
|
assert isinstance(tokenizer_config["chat_template"], str)
|
||||||
|
chat_template = tokenizer_config["chat_template"]
|
||||||
|
|
||||||
|
assert "bos_token" in tokenizer_config
|
||||||
|
assert isinstance(tokenizer_config["bos_token"], str)
|
||||||
|
bos_token = tokenizer_config["bos_token"]
|
||||||
|
|
||||||
|
assert "eos_token" in tokenizer_config
|
||||||
|
assert isinstance(tokenizer_config["eos_token"], str)
|
||||||
|
eos_token = tokenizer_config["eos_token"]
|
||||||
|
|
||||||
|
env = jinja2.Environment(
|
||||||
|
loader=jinja2.BaseLoader(),
|
||||||
|
trim_blocks=True,
|
||||||
|
lstrip_blocks=True,
|
||||||
|
).from_string(chat_template)
|
||||||
|
|
||||||
|
def format_autotokenizer(
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatFormatterResponse:
|
||||||
|
# TODO: veryify this is correct
|
||||||
|
# Add a blank assistant message to the end of the messages to prompt the model to generate a response
|
||||||
|
prompt = env.render(
|
||||||
|
messages=[
|
||||||
|
*messages,
|
||||||
|
llama_types.ChatCompletionRequestAssistantMessage(
|
||||||
|
role="assistant", content=""
|
||||||
|
),
|
||||||
|
],
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
)
|
||||||
|
return ChatFormatterResponse(prompt=prompt, stop=eos_token)
|
||||||
|
return format_autotokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def hf_tokenizer_config_to_chat_completion_handler(
|
||||||
|
tokenizer_config: Dict[str, Any],
|
||||||
|
) -> LlamaChatCompletionHandler:
|
||||||
|
chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
|
||||||
|
return chat_formatter_to_chat_completion_handler(chat_formatter)
|
||||||
|
|
||||||
|
|
||||||
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
|
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
|
||||||
# system prompt is "embedded" in the first message
|
# system prompt is "embedded" in the first message
|
||||||
@register_chat_format("llama-2")
|
@register_chat_format("llama-2")
|
||||||
|
@ -437,6 +544,7 @@ def format_alpaca(
|
||||||
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
|
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
|
||||||
return ChatFormatterResponse(prompt=_prompt)
|
return ChatFormatterResponse(prompt=_prompt)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("qwen")
|
@register_chat_format("qwen")
|
||||||
def format_qwen(
|
def format_qwen(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
@ -453,6 +561,7 @@ def format_qwen(
|
||||||
_sep2 = "<|endoftext|>"
|
_sep2 = "<|endoftext|>"
|
||||||
return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
|
return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("vicuna")
|
@register_chat_format("vicuna")
|
||||||
def format(
|
def format(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
@ -650,6 +759,7 @@ def format_mistrallite(
|
||||||
_prompt = _format_no_colon_single(system_message, _messages, _sep)
|
_prompt = _format_no_colon_single(system_message, _messages, _sep)
|
||||||
return ChatFormatterResponse(prompt=_prompt)
|
return ChatFormatterResponse(prompt=_prompt)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("zephyr")
|
@register_chat_format("zephyr")
|
||||||
def format_zephyr(
|
def format_zephyr(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
@ -699,6 +809,7 @@ def format_chatml(
|
||||||
_prompt = _format_chatml(system_message, _messages, _sep)
|
_prompt = _format_chatml(system_message, _messages, _sep)
|
||||||
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
|
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
|
||||||
|
|
||||||
|
|
||||||
@register_chat_format("chatglm3")
|
@register_chat_format("chatglm3")
|
||||||
def format_chatglm3(
|
def format_chatglm3(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
@ -739,7 +850,7 @@ def format_openchat(
|
||||||
@register_chat_format("saiga")
|
@register_chat_format("saiga")
|
||||||
def format_saiga(
|
def format_saiga(
|
||||||
messages: list[llama_types.ChatCompletionRequestMessage],
|
messages: list[llama_types.ChatCompletionRequestMessage],
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
_message_template = "<s>{role}\n{content}</s>"
|
_message_template = "<s>{role}\n{content}</s>"
|
||||||
_roles = dict(user="user", bot="bot", system="system")
|
_roles = dict(user="user", bot="bot", system="system")
|
||||||
|
|
|
@ -1,138 +0,0 @@
|
||||||
"""
|
|
||||||
llama_cpp/llama_jinja_format.py
|
|
||||||
"""
|
|
||||||
import dataclasses
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
|
|
||||||
|
|
||||||
import jinja2
|
|
||||||
from jinja2 import Template
|
|
||||||
|
|
||||||
# NOTE: We sacrifice readability for usability.
|
|
||||||
# It will fail to work as expected if we attempt to format it in a readable way.
|
|
||||||
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
|
|
||||||
|
|
||||||
|
|
||||||
class MetaSingleton(type):
|
|
||||||
"""
|
|
||||||
Metaclass for implementing the Singleton pattern.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_instances = {}
|
|
||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
|
||||||
if cls not in cls._instances:
|
|
||||||
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
|
|
||||||
return cls._instances[cls]
|
|
||||||
|
|
||||||
|
|
||||||
class Singleton(object, metaclass=MetaSingleton):
|
|
||||||
"""
|
|
||||||
Base class for implementing the Singleton pattern.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(Singleton, self).__init__()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ChatFormatterResponse:
|
|
||||||
prompt: str
|
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
|
||||||
|
|
||||||
|
|
||||||
# Base Chat Formatter Protocol
|
|
||||||
class ChatFormatterInterface(Protocol):
|
|
||||||
def __init__(self, template: Optional[object] = None):
|
|
||||||
...
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
**kwargs,
|
|
||||||
) -> ChatFormatterResponse:
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def template(self) -> str:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class AutoChatFormatter(ChatFormatterInterface):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
template: Optional[str] = None,
|
|
||||||
template_class: Optional[Template] = None,
|
|
||||||
):
|
|
||||||
if template is not None:
|
|
||||||
self._template = template
|
|
||||||
else:
|
|
||||||
self._template = llama2_template # default template
|
|
||||||
|
|
||||||
self._environment = jinja2.Environment(
|
|
||||||
loader=jinja2.BaseLoader(),
|
|
||||||
trim_blocks=True,
|
|
||||||
lstrip_blocks=True,
|
|
||||||
).from_string(
|
|
||||||
self._template,
|
|
||||||
template_class=template_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatFormatterResponse:
|
|
||||||
formatted_sequence = self._environment.render(messages=messages, **kwargs)
|
|
||||||
return ChatFormatterResponse(prompt=formatted_sequence)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def template(self) -> str:
|
|
||||||
return self._template
|
|
||||||
|
|
||||||
|
|
||||||
class FormatterNotFoundException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ChatFormatterFactory(Singleton):
|
|
||||||
_chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
|
|
||||||
|
|
||||||
def register_formatter(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
formatter_callable: Callable[[], ChatFormatterInterface],
|
|
||||||
overwrite=False,
|
|
||||||
):
|
|
||||||
if not overwrite and name in self._chat_formatters:
|
|
||||||
raise ValueError(
|
|
||||||
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
|
|
||||||
)
|
|
||||||
self._chat_formatters[name] = formatter_callable
|
|
||||||
|
|
||||||
def unregister_formatter(self, name: str):
|
|
||||||
if name in self._chat_formatters:
|
|
||||||
del self._chat_formatters[name]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"No formatter registered under the name '{name}'.")
|
|
||||||
|
|
||||||
def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
|
|
||||||
try:
|
|
||||||
formatter_callable = self._chat_formatters[name]
|
|
||||||
return formatter_callable()
|
|
||||||
except KeyError:
|
|
||||||
raise FormatterNotFoundException(
|
|
||||||
f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Define a chat format class
|
|
||||||
class Llama2Formatter(AutoChatFormatter):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(llama2_template)
|
|
||||||
|
|
||||||
|
|
||||||
# With the Singleton pattern applied, regardless of where or how many times
|
|
||||||
# ChatFormatterFactory() is called, it will always return the same instance
|
|
||||||
# of the factory, ensuring that the factory's state is consistent throughout
|
|
||||||
# the application.
|
|
||||||
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
|
|
|
@ -1,5 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
from typing import Dict, Optional, Union, List
|
from typing import Dict, Optional, Union, List
|
||||||
|
|
||||||
import llama_cpp
|
import llama_cpp
|
||||||
|
@ -71,6 +73,24 @@ class LlamaProxy:
|
||||||
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
||||||
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
||||||
)
|
)
|
||||||
|
elif settings.chat_format == "hf-autotokenizer":
|
||||||
|
assert (
|
||||||
|
settings.hf_pretrained_model_name_or_path is not None
|
||||||
|
), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
|
||||||
|
chat_handler = (
|
||||||
|
llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_formatter(
|
||||||
|
settings.hf_pretrained_model_name_or_path
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif settings.chat_format == "hf-tokenizer-config":
|
||||||
|
assert (
|
||||||
|
settings.hf_tokenizer_config_path is not None
|
||||||
|
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
|
||||||
|
chat_handler = (
|
||||||
|
llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_formatter(
|
||||||
|
json.load(open(settings.hf_tokenizer_config_path))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
||||||
if settings.kv_overrides is not None:
|
if settings.kv_overrides is not None:
|
||||||
|
@ -141,4 +161,3 @@ class LlamaProxy:
|
||||||
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
|
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
|
||||||
_model.set_cache(cache)
|
_model.set_cache(cache)
|
||||||
return _model
|
return _model
|
||||||
|
|
||||||
|
|
|
@ -134,6 +134,15 @@ class ModelSettings(BaseSettings):
|
||||||
default=2 << 30,
|
default=2 << 30,
|
||||||
description="The size of the cache in bytes. Only used if cache is True.",
|
description="The size of the cache in bytes. Only used if cache is True.",
|
||||||
)
|
)
|
||||||
|
# Tokenizer Options
|
||||||
|
hf_tokenizer_config_path: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The path to a HuggingFace tokenizer_config.json file.",
|
||||||
|
)
|
||||||
|
hf_pretrained_model_name_or_path: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
|
||||||
|
)
|
||||||
# Misc
|
# Misc
|
||||||
verbose: bool = Field(
|
verbose: bool = Field(
|
||||||
default=True, description="Whether to print debug information."
|
default=True, description="Whether to print debug information."
|
||||||
|
|
|
@ -1,50 +1,65 @@
|
||||||
from typing import List
|
import json
|
||||||
|
|
||||||
import pytest
|
from llama_cpp import (
|
||||||
|
ChatCompletionRequestUserMessage,
|
||||||
from llama_cpp import ChatCompletionMessage
|
)
|
||||||
from llama_cpp.llama_jinja_format import Llama2Formatter
|
from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
mistral_7b_tokenizer_config = """{
|
||||||
def sequence_of_messages() -> List[ChatCompletionMessage]:
|
"add_bos_token": true,
|
||||||
return [
|
"add_eos_token": false,
|
||||||
ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
|
"added_tokens_decoder": {
|
||||||
ChatCompletionMessage(
|
"0": {
|
||||||
role="user", content="Hi there! I need some help with Python."
|
"content": "<unk>",
|
||||||
),
|
"lstrip": false,
|
||||||
ChatCompletionMessage(
|
"normalized": false,
|
||||||
role="assistant", content="Of course! What do you need help with in Python?"
|
"rstrip": false,
|
||||||
),
|
"single_word": false,
|
||||||
ChatCompletionMessage(
|
"special": true
|
||||||
role="user",
|
},
|
||||||
content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
|
"1": {
|
||||||
),
|
"content": "<s>",
|
||||||
ChatCompletionMessage(
|
"lstrip": false,
|
||||||
role="assistant",
|
"normalized": false,
|
||||||
content="I can help with that! Would you like a recursive or iterative solution?",
|
"rstrip": false,
|
||||||
),
|
"single_word": false,
|
||||||
ChatCompletionMessage(
|
"special": true
|
||||||
role="user", content="Let's go with a recursive solution."
|
},
|
||||||
),
|
"2": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [],
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"clean_up_tokenization_spaces": false,
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"legacy": true,
|
||||||
|
"model_max_length": 1000000000000000019884624838656,
|
||||||
|
"pad_token": null,
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"spaces_between_special_tokens": false,
|
||||||
|
"tokenizer_class": "LlamaTokenizer",
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"use_default_system_prompt": false,
|
||||||
|
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
|
||||||
|
}"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_hf_tokenizer_config_str_to_chat_formatter():
|
||||||
|
tokenizer_config = json.loads(mistral_7b_tokenizer_config)
|
||||||
|
chat_formatter = hf_tokenizer_config_to_chat_formatter(
|
||||||
|
tokenizer_config
|
||||||
|
)
|
||||||
|
chat_formatter_respoonse = chat_formatter(
|
||||||
|
messages=[
|
||||||
|
ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_llama2_formatter(sequence_of_messages):
|
|
||||||
expected_prompt = (
|
|
||||||
"<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
|
|
||||||
"[INST] Hi there! I need some help with Python. [/INST]\n"
|
|
||||||
"Of course! What do you need help with in Python?\n"
|
|
||||||
"[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
|
|
||||||
"I can help with that! Would you like a recursive or iterative solution?\n"
|
|
||||||
"[INST] Let's go with a recursive solution. [/INST]\n"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
llama2_formatter_instance = Llama2Formatter()
|
assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>" "")
|
||||||
formatter_response = llama2_formatter_instance(sequence_of_messages)
|
|
||||||
assert (
|
|
||||||
expected_prompt == formatter_response.prompt
|
|
||||||
), "The formatted prompt does not match the expected output."
|
|
||||||
|
|
||||||
|
|
||||||
# Optionally, include a test for the 'stop' if it's part of the functionality.
|
|
||||||
|
|
Loading…
Reference in a new issue