feat: Add ability to load chat format from huggingface autotokenizer or tokenizer_config.json files.
This commit is contained in:
parent
48c3b77e6f
commit
b8fc1c7d83
6 changed files with 357 additions and 318 deletions
|
@ -1,7 +1,8 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import sys, traceback
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
|
||||
outnull_file = open(os.devnull, "w")
|
||||
|
@ -55,3 +56,25 @@ class suppress_stdout_stderr(object):
|
|||
|
||||
self.os.close(self.old_stdout_fileno)
|
||||
self.os.close(self.old_stderr_fileno)
|
||||
|
||||
|
||||
class MetaSingleton(type):
|
||||
"""
|
||||
Metaclass for implementing the Singleton pattern.
|
||||
"""
|
||||
|
||||
_instances: Dict[type, Any] = {}
|
||||
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class Singleton(object, metaclass=MetaSingleton):
|
||||
"""
|
||||
Base class for implementing the Singleton pattern.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Singleton, self).__init__()
|
||||
|
|
|
@ -6,18 +6,28 @@ import ctypes
|
|||
import dataclasses
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||
|
||||
import jinja2
|
||||
|
||||
import llama_cpp.llama as llama
|
||||
import llama_cpp.llama_types as llama_types
|
||||
import llama_cpp.llama_grammar as llama_grammar
|
||||
|
||||
from ._utils import suppress_stdout_stderr
|
||||
from ._utils import suppress_stdout_stderr, Singleton
|
||||
|
||||
|
||||
class LlamaChatCompletionHandler(Protocol):
|
||||
"""Base Protocol for a llama chat completion handler.
|
||||
|
||||
Very generic protocol that can be used to implement any chat format.
|
||||
The only hard requirement is that it must return a ChatCompletion when
|
||||
stream=False and an iterator of ChatCompletionChunks when stream=True."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
# llama.cpp instance
|
||||
llama: llama.Llama,
|
||||
# openai api parameters
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||
|
@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol):
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
|
@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol):
|
|||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
model: Optional[str] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
# llama.cpp parameters
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
tfs_z: float = 1.0,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
|
@ -54,21 +65,83 @@ class LlamaChatCompletionHandler(Protocol):
|
|||
...
|
||||
|
||||
|
||||
CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
|
||||
class LlamaChatCompletionHandlerNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LlamaChatCompletionHandlerRegistry(Singleton):
|
||||
_chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
|
||||
|
||||
def register_chat_completion_handler(
|
||||
self,
|
||||
name: str,
|
||||
chat_handler: LlamaChatCompletionHandler,
|
||||
overwrite: bool = False,
|
||||
):
|
||||
if not overwrite and name in self._chat_handlers:
|
||||
raise ValueError(
|
||||
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
|
||||
)
|
||||
self._chat_handlers[name] = chat_handler
|
||||
|
||||
def unregister_chat_handler(self, name: str):
|
||||
if name in self._chat_handlers:
|
||||
del self._chat_handlers[name]
|
||||
else:
|
||||
raise ValueError(f"No formatter registered under the name '{name}'.")
|
||||
|
||||
def get_chat_completion_handler_by_name(
|
||||
self, name: str
|
||||
) -> LlamaChatCompletionHandler:
|
||||
try:
|
||||
chat_handler = self._chat_handlers[name]
|
||||
return chat_handler
|
||||
except KeyError:
|
||||
raise LlamaChatCompletionHandlerNotFoundException(
|
||||
f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
|
||||
)
|
||||
|
||||
|
||||
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
|
||||
return CHAT_HANDLERS[name]
|
||||
return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
|
||||
name
|
||||
)
|
||||
|
||||
|
||||
def register_chat_completion_handler(name: str):
|
||||
def decorator(f: LlamaChatCompletionHandler):
|
||||
CHAT_HANDLERS[name] = f
|
||||
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
### Chat Formatter ###
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ChatFormatterResponse:
|
||||
prompt: str
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
|
||||
class ChatFormatter(Protocol):
|
||||
"""Base Protocol for a chat formatter. A chat formatter is a function that
|
||||
takes a list of messages and returns a formatted prompt. It can also return
|
||||
a stop token or list of stop tokens to use for the completion."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
...
|
||||
|
||||
|
||||
### Utility functions for formatting chat prompts ###
|
||||
|
||||
|
||||
def _get_system_message(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
) -> str:
|
||||
|
@ -80,14 +153,18 @@ def _get_system_message(
|
|||
|
||||
|
||||
def _map_roles(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str]
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
role_map: Dict[str, str],
|
||||
) -> List[Tuple[str, Optional[str]]]:
|
||||
"""Map the message roles."""
|
||||
output: List[Tuple[str, Optional[str]]] = []
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
if role in role_map:
|
||||
output.append((role_map[role], message["content"]))
|
||||
content: str | None = (
|
||||
message["content"] if isinstance(message["content"], str) else None
|
||||
)
|
||||
output.append((role_map[role], content))
|
||||
return output
|
||||
|
||||
|
||||
|
@ -99,7 +176,8 @@ def _format_llama2(
|
|||
ret = system_message + sep
|
||||
for i, (role, message) in enumerate(messages):
|
||||
if system_message and i == 0:
|
||||
ret += message + seps[i % 2]
|
||||
m = message or ""
|
||||
ret += m + seps[i % 2]
|
||||
elif message:
|
||||
ret += role + message + " " + seps[i % 2]
|
||||
else:
|
||||
|
@ -172,6 +250,7 @@ def _format_chatml(
|
|||
ret += role + "\n"
|
||||
return ret
|
||||
|
||||
|
||||
def _format_chatglm3(
|
||||
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
|
||||
) -> str:
|
||||
|
@ -187,30 +266,10 @@ def _format_chatglm3(
|
|||
return ret
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ChatFormatterResponse:
|
||||
prompt: str
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
|
||||
class ChatFormatter(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
...
|
||||
|
||||
|
||||
class BasicChatHandler:
|
||||
def __init__(self, chat_format: str):
|
||||
self.chat_format = chat_format
|
||||
|
||||
|
||||
def _convert_text_completion_to_chat(
|
||||
completion: llama_types.Completion,
|
||||
) -> llama_types.ChatCompletion:
|
||||
assert "usage" in completion
|
||||
return {
|
||||
"id": "chat" + completion["id"],
|
||||
"object": "chat.completion",
|
||||
|
@ -286,103 +345,95 @@ def _convert_completion_to_chat(
|
|||
return _convert_text_completion_to_chat(completion)
|
||||
|
||||
|
||||
_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
|
||||
def chat_formatter_to_chat_completion_handler(
|
||||
chat_formatter: ChatFormatter,
|
||||
) -> LlamaChatCompletionHandler:
|
||||
def chat_completion_handler(
|
||||
*,
|
||||
llama: llama.Llama,
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
response_format: Optional[
|
||||
llama_types.ChatCompletionRequestResponseFormat
|
||||
] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
tfs_z: float = 1.0,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||
]:
|
||||
result = chat_formatter(
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
function_call=function_call,
|
||||
)
|
||||
prompt = result.prompt
|
||||
if result.stop is not None:
|
||||
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||
stop = stop + rstop
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||
|
||||
completion_or_chunks = llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
logit_bias=logit_bias,
|
||||
)
|
||||
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
||||
|
||||
return chat_completion_handler
|
||||
|
||||
|
||||
def register_chat_format(name: str):
|
||||
def decorator(f: ChatFormatter):
|
||||
def basic_create_chat_completion(
|
||||
*,
|
||||
llama: llama.Llama,
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||
function_call: Optional[
|
||||
llama_types.ChatCompletionRequestFunctionCall
|
||||
] = None,
|
||||
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
response_format: Optional[
|
||||
llama_types.ChatCompletionRequestResponseFormat
|
||||
] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
tfs_z: float = 1.0,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||
]:
|
||||
result = f(
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
function_call=function_call,
|
||||
)
|
||||
prompt = result.prompt
|
||||
if result.stop is not None:
|
||||
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||
stop = stop + rstop
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
||||
llama_grammar.JSON_GBNF
|
||||
)
|
||||
|
||||
completion_or_chunks = llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
logit_bias=logit_bias,
|
||||
)
|
||||
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
||||
|
||||
register_chat_completion_handler(name)(basic_create_chat_completion)
|
||||
return f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_chat_format(name: str):
|
||||
try:
|
||||
return _CHAT_FORMATS[name]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
|
||||
chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
|
||||
LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
|
||||
name, chat_completion_handler
|
||||
)
|
||||
return f
|
||||
return decorator
|
||||
|
||||
|
||||
def hf_autotokenizer_to_chat_formatter(
|
||||
|
@ -391,22 +442,78 @@ def hf_autotokenizer_to_chat_formatter(
|
|||
# https://huggingface.co/docs/transformers/main/chat_templating
|
||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
|
||||
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore
|
||||
|
||||
def format_autotokenizer(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
tokenizer.use_default_system_prompt = False
|
||||
_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
tokenizer.use_default_system_prompt = False # type: ignore
|
||||
prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
|
||||
assert isinstance(prompt, str)
|
||||
# Return formatted prompt and eos token by default
|
||||
return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
|
||||
return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
|
||||
|
||||
return format_autotokenizer
|
||||
|
||||
|
||||
def hf_autotokenizer_to_chat_completion_handler(
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike[str]]
|
||||
) -> LlamaChatCompletionHandler:
|
||||
chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
|
||||
return chat_formatter_to_chat_completion_handler(chat_formatter)
|
||||
|
||||
|
||||
def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
|
||||
assert isinstance(tokenizer_config, dict)
|
||||
|
||||
assert "chat_template" in tokenizer_config
|
||||
assert isinstance(tokenizer_config["chat_template"], str)
|
||||
chat_template = tokenizer_config["chat_template"]
|
||||
|
||||
assert "bos_token" in tokenizer_config
|
||||
assert isinstance(tokenizer_config["bos_token"], str)
|
||||
bos_token = tokenizer_config["bos_token"]
|
||||
|
||||
assert "eos_token" in tokenizer_config
|
||||
assert isinstance(tokenizer_config["eos_token"], str)
|
||||
eos_token = tokenizer_config["eos_token"]
|
||||
|
||||
env = jinja2.Environment(
|
||||
loader=jinja2.BaseLoader(),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
).from_string(chat_template)
|
||||
|
||||
def format_autotokenizer(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
# TODO: veryify this is correct
|
||||
# Add a blank assistant message to the end of the messages to prompt the model to generate a response
|
||||
prompt = env.render(
|
||||
messages=[
|
||||
*messages,
|
||||
llama_types.ChatCompletionRequestAssistantMessage(
|
||||
role="assistant", content=""
|
||||
),
|
||||
],
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
)
|
||||
return ChatFormatterResponse(prompt=prompt, stop=eos_token)
|
||||
return format_autotokenizer
|
||||
|
||||
|
||||
def hf_tokenizer_config_to_chat_completion_handler(
|
||||
tokenizer_config: Dict[str, Any],
|
||||
) -> LlamaChatCompletionHandler:
|
||||
chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
|
||||
return chat_formatter_to_chat_completion_handler(chat_formatter)
|
||||
|
||||
|
||||
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
|
||||
# system prompt is "embedded" in the first message
|
||||
@register_chat_format("llama-2")
|
||||
|
@ -437,21 +544,23 @@ def format_alpaca(
|
|||
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
|
||||
return ChatFormatterResponse(prompt=_prompt)
|
||||
|
||||
|
||||
@register_chat_format("qwen")
|
||||
def format_qwen(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
_roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
|
||||
system_message="You are a helpful assistant."
|
||||
system_template="<|im_start|>system\n{system_message}"
|
||||
system_message=system_template.format(system_message=system_message)
|
||||
system_message = "You are a helpful assistant."
|
||||
system_template = "<|im_start|>system\n{system_message}"
|
||||
system_message = system_template.format(system_message=system_message)
|
||||
_messages = _map_roles(messages, _roles)
|
||||
_messages.append((_roles["assistant"], None))
|
||||
_sep = "<|im_end|>"
|
||||
_prompt = _format_chatml(system_message, _messages, _sep)
|
||||
_sep2 = "<|endoftext|>"
|
||||
return ChatFormatterResponse(prompt=_prompt,stop=_sep2)
|
||||
return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
|
||||
|
||||
|
||||
@register_chat_format("vicuna")
|
||||
def format(
|
||||
|
@ -650,6 +759,7 @@ def format_mistrallite(
|
|||
_prompt = _format_no_colon_single(system_message, _messages, _sep)
|
||||
return ChatFormatterResponse(prompt=_prompt)
|
||||
|
||||
|
||||
@register_chat_format("zephyr")
|
||||
def format_zephyr(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
|
@ -699,6 +809,7 @@ def format_chatml(
|
|||
_prompt = _format_chatml(system_message, _messages, _sep)
|
||||
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
|
||||
|
||||
|
||||
@register_chat_format("chatglm3")
|
||||
def format_chatglm3(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
|
@ -739,7 +850,7 @@ def format_openchat(
|
|||
@register_chat_format("saiga")
|
||||
def format_saiga(
|
||||
messages: list[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs,
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
_message_template = "<s>{role}\n{content}</s>"
|
||||
_roles = dict(user="user", bot="bot", system="system")
|
||||
|
|
|
@ -1,138 +0,0 @@
|
|||
"""
|
||||
llama_cpp/llama_jinja_format.py
|
||||
"""
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
|
||||
|
||||
import jinja2
|
||||
from jinja2 import Template
|
||||
|
||||
# NOTE: We sacrifice readability for usability.
|
||||
# It will fail to work as expected if we attempt to format it in a readable way.
|
||||
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
|
||||
|
||||
|
||||
class MetaSingleton(type):
|
||||
"""
|
||||
Metaclass for implementing the Singleton pattern.
|
||||
"""
|
||||
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class Singleton(object, metaclass=MetaSingleton):
|
||||
"""
|
||||
Base class for implementing the Singleton pattern.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Singleton, self).__init__()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ChatFormatterResponse:
|
||||
prompt: str
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
|
||||
# Base Chat Formatter Protocol
|
||||
class ChatFormatterInterface(Protocol):
|
||||
def __init__(self, template: Optional[object] = None):
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
**kwargs,
|
||||
) -> ChatFormatterResponse:
|
||||
...
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
...
|
||||
|
||||
|
||||
class AutoChatFormatter(ChatFormatterInterface):
|
||||
def __init__(
|
||||
self,
|
||||
template: Optional[str] = None,
|
||||
template_class: Optional[Template] = None,
|
||||
):
|
||||
if template is not None:
|
||||
self._template = template
|
||||
else:
|
||||
self._template = llama2_template # default template
|
||||
|
||||
self._environment = jinja2.Environment(
|
||||
loader=jinja2.BaseLoader(),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
).from_string(
|
||||
self._template,
|
||||
template_class=template_class,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
formatted_sequence = self._environment.render(messages=messages, **kwargs)
|
||||
return ChatFormatterResponse(prompt=formatted_sequence)
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return self._template
|
||||
|
||||
|
||||
class FormatterNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ChatFormatterFactory(Singleton):
|
||||
_chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
|
||||
|
||||
def register_formatter(
|
||||
self,
|
||||
name: str,
|
||||
formatter_callable: Callable[[], ChatFormatterInterface],
|
||||
overwrite=False,
|
||||
):
|
||||
if not overwrite and name in self._chat_formatters:
|
||||
raise ValueError(
|
||||
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
|
||||
)
|
||||
self._chat_formatters[name] = formatter_callable
|
||||
|
||||
def unregister_formatter(self, name: str):
|
||||
if name in self._chat_formatters:
|
||||
del self._chat_formatters[name]
|
||||
else:
|
||||
raise ValueError(f"No formatter registered under the name '{name}'.")
|
||||
|
||||
def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
|
||||
try:
|
||||
formatter_callable = self._chat_formatters[name]
|
||||
return formatter_callable()
|
||||
except KeyError:
|
||||
raise FormatterNotFoundException(
|
||||
f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
|
||||
)
|
||||
|
||||
|
||||
# Define a chat format class
|
||||
class Llama2Formatter(AutoChatFormatter):
|
||||
def __init__(self):
|
||||
super().__init__(llama2_template)
|
||||
|
||||
|
||||
# With the Singleton pattern applied, regardless of where or how many times
|
||||
# ChatFormatterFactory() is called, it will always return the same instance
|
||||
# of the factory, ensuring that the factory's state is consistent throughout
|
||||
# the application.
|
||||
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from typing import Dict, Optional, Union, List
|
||||
|
||||
import llama_cpp
|
||||
|
@ -71,7 +73,25 @@ class LlamaProxy:
|
|||
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
||||
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
||||
)
|
||||
|
||||
elif settings.chat_format == "hf-autotokenizer":
|
||||
assert (
|
||||
settings.hf_pretrained_model_name_or_path is not None
|
||||
), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
|
||||
chat_handler = (
|
||||
llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_formatter(
|
||||
settings.hf_pretrained_model_name_or_path
|
||||
)
|
||||
)
|
||||
elif settings.chat_format == "hf-tokenizer-config":
|
||||
assert (
|
||||
settings.hf_tokenizer_config_path is not None
|
||||
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
|
||||
chat_handler = (
|
||||
llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_formatter(
|
||||
json.load(open(settings.hf_tokenizer_config_path))
|
||||
)
|
||||
)
|
||||
|
||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
||||
if settings.kv_overrides is not None:
|
||||
assert isinstance(settings.kv_overrides, list)
|
||||
|
@ -141,4 +161,3 @@ class LlamaProxy:
|
|||
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
|
||||
_model.set_cache(cache)
|
||||
return _model
|
||||
|
||||
|
|
|
@ -134,6 +134,15 @@ class ModelSettings(BaseSettings):
|
|||
default=2 << 30,
|
||||
description="The size of the cache in bytes. Only used if cache is True.",
|
||||
)
|
||||
# Tokenizer Options
|
||||
hf_tokenizer_config_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The path to a HuggingFace tokenizer_config.json file.",
|
||||
)
|
||||
hf_pretrained_model_name_or_path: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
|
||||
)
|
||||
# Misc
|
||||
verbose: bool = Field(
|
||||
default=True, description="Whether to print debug information."
|
||||
|
|
|
@ -1,50 +1,65 @@
|
|||
from typing import List
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_cpp import ChatCompletionMessage
|
||||
from llama_cpp.llama_jinja_format import Llama2Formatter
|
||||
from llama_cpp import (
|
||||
ChatCompletionRequestUserMessage,
|
||||
)
|
||||
from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sequence_of_messages() -> List[ChatCompletionMessage]:
|
||||
return [
|
||||
ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Hi there! I need some help with Python."
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="assistant", content="Of course! What do you need help with in Python?"
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="user",
|
||||
content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="I can help with that! Would you like a recursive or iterative solution?",
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Let's go with a recursive solution."
|
||||
),
|
||||
]
|
||||
mistral_7b_tokenizer_config = """{
|
||||
"add_bos_token": true,
|
||||
"add_eos_token": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [],
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "</s>",
|
||||
"legacy": true,
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"pad_token": null,
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"tokenizer_class": "LlamaTokenizer",
|
||||
"unk_token": "<unk>",
|
||||
"use_default_system_prompt": false,
|
||||
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
|
||||
}"""
|
||||
|
||||
|
||||
def test_llama2_formatter(sequence_of_messages):
|
||||
expected_prompt = (
|
||||
"<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
|
||||
"[INST] Hi there! I need some help with Python. [/INST]\n"
|
||||
"Of course! What do you need help with in Python?\n"
|
||||
"[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
|
||||
"I can help with that! Would you like a recursive or iterative solution?\n"
|
||||
"[INST] Let's go with a recursive solution. [/INST]\n"
|
||||
def test_hf_tokenizer_config_str_to_chat_formatter():
|
||||
tokenizer_config = json.loads(mistral_7b_tokenizer_config)
|
||||
chat_formatter = hf_tokenizer_config_to_chat_formatter(
|
||||
tokenizer_config
|
||||
)
|
||||
chat_formatter_respoonse = chat_formatter(
|
||||
messages=[
|
||||
ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
|
||||
]
|
||||
)
|
||||
|
||||
llama2_formatter_instance = Llama2Formatter()
|
||||
formatter_response = llama2_formatter_instance(sequence_of_messages)
|
||||
assert (
|
||||
expected_prompt == formatter_response.prompt
|
||||
), "The formatted prompt does not match the expected output."
|
||||
|
||||
|
||||
# Optionally, include a test for the 'stop' if it's part of the functionality.
|
||||
assert chat_formatter_respoonse.prompt == ("<s>[INST] Hello, world! [/INST]</s>" "")
|
||||
|
|
Loading…
Reference in a new issue