6bfe98bd80
* feat: Add support for jinja templating

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* fix: Refactor chat formatter and update interface for jinja templates

  - Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks for readability without affecting functionality.
  - Update the `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility.
  - Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string.
  - Replace the `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation.

  These changes enhance code readability, maintain usability, and ensure consistency in the chat formatter's design pattern usage.

* Add outline for Jinja2 templating integration documentation

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Add jinja2 as a dependency with version range for Hugging Face transformers compatibility

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Update jinja2 version constraint for mkdocs-material compatibility

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Fix attribute name in AutoChatFormatter

  - Changed attribute name from `self._renderer` to `self._environment`

---------

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
138 lines
4.1 KiB
Python
"""
|
|
llama_cpp/llama_jinja_format.py
|
|
"""
|
|
import dataclasses
|
|
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
|
|
|
|
import jinja2
|
|
from jinja2 import Template
|
|
|
|
# NOTE: We sacrifice readability for usability.
|
|
# It will fail to work as expected if we attempt to format it in a readable way.
|
|
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
|
|
|
|
|
|
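

# Illustrative sketch (not part of the original module): rendering the template above
# for a short conversation yields a prompt roughly of the form
#   "<<SYS>> ... <</SYS>>\n[INST] ... [/INST]\n...\n"
# The helper below is hypothetical and only demonstrates the template's shape.
def _example_render_llama2_template() -> str:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi there."},
        {"role": "assistant", "content": "Hello!"},
    ]
    # Render directly with jinja2; AutoChatFormatter below wraps the same idea.
    return jinja2.Template(llama2_template).render(messages=messages)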


class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.
    """

    def __init__(self):
        super(Singleton, self).__init__()
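

# Sketch (hypothetical helper, not in the original module): any subclass of Singleton
# hands back one shared instance no matter how many times it is constructed.
def _example_singleton_identity() -> bool:
    class _Example(Singleton):
        pass

    # Both constructor calls resolve to the same cached instance, so this is True.
    return _Example() is _Example()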


@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
    def __init__(self, template: Optional[object] = None):
        ...

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> ChatFormatterResponse:
        ...

    @property
    def template(self) -> str:
        ...


class AutoChatFormatter(ChatFormatterInterface):
    def __init__(
        self,
        template: Optional[str] = None,
        template_class: Optional[Template] = None,
    ):
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        formatted_sequence = self._environment.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        return self._template
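

# Sketch (hypothetical helper, not in the original module): AutoChatFormatter accepts
# any Jinja2 chat template string; "simple_template" below is a made-up example.
def _example_custom_formatter() -> ChatFormatterResponse:
    simple_template = (
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
    )
    formatter = AutoChatFormatter(template=simple_template)
    return formatter([{"role": "user", "content": "Hello"}])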


class FormatterNotFoundException(Exception):
    pass


class ChatFormatterFactory(Singleton):
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite=False,
    ):
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        if name in self._chat_formatters:
            del self._chat_formatters[name]
        else:
            raise ValueError(f"No formatter registered under the name '{name}'.")

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        try:
            formatter_callable = self._chat_formatters[name]
            return formatter_callable()
        except KeyError:
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            )
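

# Sketch (hypothetical helper, not in the original module): register a made-up
# formatter, resolve it by name, then remove it again. The "my-format" name and
# _MyFormatter class exist only for this illustration.
def _example_factory_usage() -> ChatFormatterInterface:
    class _MyFormatter(AutoChatFormatter):
        def __init__(self):
            super().__init__("{{ messages[-1]['content'] }}")

    factory = ChatFormatterFactory()
    factory.register_formatter("my-format", _MyFormatter, overwrite=True)
    formatter = factory.get_formatter_by_name("my-format")
    factory.unregister_formatter("my-format")
    return formatter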


# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    def __init__(self):
        super().__init__(llama2_template)


# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
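

# Illustrative usage sketch (not part of the original module): resolve the registered
# "llama-2" formatter from the factory singleton and format a short conversation.
if __name__ == "__main__":
    chat_formatter = ChatFormatterFactory().get_formatter_by_name("llama-2")
    response = chat_formatter(
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, world!"},
        ]
    )
    print(response.prompt)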