""" llama_cpp/llama_jinja_format.py """ import dataclasses from typing import Any, Callable, Dict, List, Optional, Protocol, Union import jinja2 from jinja2 import Template # NOTE: We sacrifice readability for usability. # It will fail to work as expected if we attempt to format it in a readable way. llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<> {{ message['content'] }} <>\n{% endif %}{% endfor %}""" class MetaSingleton(type): """ Metaclass for implementing the Singleton pattern. """ _instances = {} def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs) return cls._instances[cls] class Singleton(object, metaclass=MetaSingleton): """ Base class for implementing the Singleton pattern. """ def __init__(self): super(Singleton, self).__init__() @dataclasses.dataclass class ChatFormatterResponse: prompt: str stop: Optional[Union[str, List[str]]] = None # Base Chat Formatter Protocol class ChatFormatterInterface(Protocol): def __init__(self, template: Optional[object] = None): ... def __call__( self, messages: List[Dict[str, str]], **kwargs, ) -> ChatFormatterResponse: ... @property def template(self) -> str: ... 
class AutoChatFormatter(ChatFormatterInterface):
    """Chat formatter that renders a message list through a Jinja2 template."""

    def __init__(
        self,
        template: Optional[str] = None,
        template_class: Optional[Template] = None,
    ):
        """Compile *template* (or the default llama-2 template) into a Jinja2 template.

        :param template: Raw Jinja2 template source; defaults to ``llama2_template``.
        :param template_class: Optional custom ``jinja2.Template`` subclass passed
            through to ``Environment.from_string``.
        """
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """Render *messages* into a prompt string.

        :param messages: List of dicts with at least ``role`` and ``content`` keys.
        :returns: A ``ChatFormatterResponse`` whose ``prompt`` is the rendered text.
        """
        formatted_sequence = self._environment.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        """The raw (uncompiled) template source string."""
        return self._template


class FormatterNotFoundException(Exception):
    """Raised when no formatter is registered under a requested format name."""


class ChatFormatterFactory(Singleton):
    """Singleton registry mapping chat-format names to formatter factories."""

    # Shared class-level registry; acceptable because the factory is a singleton.
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite: bool = False,
    ):
        """Register *formatter_callable* under *name*.

        :raises ValueError: if *name* is already taken and ``overwrite`` is False.
        """
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        """Remove the formatter registered under *name*.

        :raises ValueError: if no formatter is registered under *name*.
        """
        if name in self._chat_formatters:
            del self._chat_formatters[name]
        else:
            raise ValueError(f"No formatter registered under the name '{name}'.")

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        """Instantiate and return the formatter registered under *name*.

        :raises FormatterNotFoundException: if *name* is not registered.
        """
        try:
            formatter_callable = self._chat_formatters[name]
        except KeyError:
            # `from None`: the internal KeyError is an implementation detail.
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            ) from None
        # Instantiate OUTSIDE the try block so a KeyError raised inside the
        # formatter's own constructor is not misreported as "format not found".
        return formatter_callable()


# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    """AutoChatFormatter pre-configured with the llama-2 prompt template."""

    def __init__(self):
        super().__init__(llama2_template)


# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)