diff --git a/docs/templates.md b/docs/templates.md
new file mode 100644
index 0000000..5acdaa1
--- /dev/null
+++ b/docs/templates.md
@@ -0,0 +1,52 @@
+# Templates
+
+This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.
+
+## Introduction
+
+- Brief explanation of the `llama-cpp-python` project's need for a templating system.
+- Overview of the `llama-2` model's interaction with templating.
+
+## Jinja2 Dependency Integration
+
+- Rationale for choosing Jinja2 as the templating engine.
+  - Compatibility with Hugging Face's `transformers`.
+  - Desire for advanced templating features and simplicity.
+- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.
+
+## Template Management Refactor
+
+- Summary of the refactor and the motivation behind it.
+- Description of the new chat handler selection logic:
+  1. Preference for a user-specified `chat_handler`.
+  2. Fallback to a user-specified `chat_format`.
+  3. Defaulting to a chat format from a `.gguf` file if available.
+  4. Utilizing the `llama2` default chat format as the final fallback.
+- Ensuring backward compatibility throughout the refactor.
+
+## Implementation Details
+
+- In-depth look at the new `AutoChatFormatter` class.
+- Example code snippets showing how to utilize the Jinja2 environment and templates.
+- Guidance on how to provide custom templates or use defaults.
+
+## Testing and Validation
+
+- Outline of the testing strategy to ensure seamless integration.
+- Steps for validating backward compatibility with existing implementations.
+
+## Benefits and Impact
+
+- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
+- Discussion of the potential impact on current users and contributors.
+
+## Future Work
+
+- Exploration of how templating can evolve within the project.
+- Consideration of additional features or optimizations for the templating engine.
+- Mechanisms for community feedback on the templating system.
+
+## Conclusion
+
+- Final thoughts on the integration of Jinja2 templating.
+- Call to action for community involvement and feedback.
diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py
new file mode 100644
index 0000000..68faaf6
--- /dev/null
+++ b/llama_cpp/llama_jinja_format.py
@@ -0,0 +1,138 @@
+"""
+llama_cpp/llama_jinja_format.py
+"""
+import dataclasses
+from typing import Any, Callable, Dict, List, Optional, Protocol, Union
+
+import jinja2
+from jinja2 import Template
+
+# NOTE: The template is deliberately kept on a single line; splitting it across
+# lines for readability would leak extra whitespace and newlines into the rendered prompt.
+llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
+
+
+class MetaSingleton(type):
+    """
+    Metaclass for implementing the Singleton pattern.
+    """
+
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class Singleton(object, metaclass=MetaSingleton):
+    """
+    Base class for implementing the Singleton pattern.
+ """ + + def __init__(self): + super(Singleton, self).__init__() + + +@dataclasses.dataclass +class ChatFormatterResponse: + prompt: str + stop: Optional[Union[str, List[str]]] = None + + +# Base Chat Formatter Protocol +class ChatFormatterInterface(Protocol): + def __init__(self, template: Optional[object] = None): + ... + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs, + ) -> ChatFormatterResponse: + ... + + @property + def template(self) -> str: + ... + + +class AutoChatFormatter(ChatFormatterInterface): + def __init__( + self, + template: Optional[str] = None, + template_class: Optional[Template] = None, + ): + if template is not None: + self._template = template + else: + self._template = llama2_template # default template + + self._environment = jinja2.Environment( + loader=jinja2.BaseLoader(), + trim_blocks=True, + lstrip_blocks=True, + ).from_string( + self._template, + template_class=template_class, + ) + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs: Any, + ) -> ChatFormatterResponse: + formatted_sequence = self._environment.render(messages=messages, **kwargs) + return ChatFormatterResponse(prompt=formatted_sequence) + + @property + def template(self) -> str: + return self._template + + +class FormatterNotFoundException(Exception): + pass + + +class ChatFormatterFactory(Singleton): + _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {} + + def register_formatter( + self, + name: str, + formatter_callable: Callable[[], ChatFormatterInterface], + overwrite=False, + ): + if not overwrite and name in self._chat_formatters: + raise ValueError( + f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it." + ) + self._chat_formatters[name] = formatter_callable + + def unregister_formatter(self, name: str): + if name in self._chat_formatters: + del self._chat_formatters[name] + else: + raise ValueError(f"No formatter registered under the name '{name}'.") + + def get_formatter_by_name(self, name: str) -> ChatFormatterInterface: + try: + formatter_callable = self._chat_formatters[name] + return formatter_callable() + except KeyError: + raise FormatterNotFoundException( + f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})" + ) + + +# Define a chat format class +class Llama2Formatter(AutoChatFormatter): + def __init__(self): + super().__init__(llama2_template) + + +# With the Singleton pattern applied, regardless of where or how many times +# ChatFormatterFactory() is called, it will always return the same instance +# of the factory, ensuring that the factory's state is consistent throughout +# the application. 
+ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
diff --git a/pyproject.toml b/pyproject.toml
index b5affaa..806127d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,10 +11,13 @@ license = { text = "MIT" }
 authors = [
     { name = "Andrei Betlen", email = "abetlen@gmail.com" },
 ]
+# mkdocs-material requires "jinja2~=3.0"
+# transformers requires "jinja2>=2.11.3"
 dependencies = [
     "typing-extensions>=4.5.0",
     "numpy>=1.20.0",
     "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
 ]
 requires-python = ">=3.8"
 classifiers = [
@@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
 
 [tool.pytest.ini_options]
 addopts = "--ignore=vendor"
-
diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py
new file mode 100644
index 0000000..4eebcb6
--- /dev/null
+++ b/tests/test_llama_chat_format.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import pytest
+
+from llama_cpp import ChatCompletionMessage
+from llama_cpp.llama_jinja_format import Llama2Formatter
+
+
+@pytest.fixture
+def sequence_of_messages() -> List[ChatCompletionMessage]:
+    return [
+        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
+        ChatCompletionMessage(
+            role="user", content="Hi there! I need some help with Python."
+        ),
+        ChatCompletionMessage(
+            role="assistant", content="Of course! What do you need help with in Python?"
+        ),
+        ChatCompletionMessage(
+            role="user",
+            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
+        ),
+        ChatCompletionMessage(
+            role="assistant",
+            content="I can help with that! Would you like a recursive or iterative solution?",
+        ),
+        ChatCompletionMessage(
+            role="user", content="Let's go with a recursive solution."
+        ),
+    ]
+
+
+def test_llama2_formatter(sequence_of_messages):
+    expected_prompt = (
+        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
+        "[INST] Hi there! I need some help with Python. [/INST]\n"
+        "Of course! What do you need help with in Python?\n"
+        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
+        "I can help with that! Would you like a recursive or iterative solution?\n"
+        "[INST] Let's go with a recursive solution. [/INST]\n"
+    )
+
+    llama2_formatter_instance = Llama2Formatter()
+    formatter_response = llama2_formatter_instance(sequence_of_messages)
+    assert (
+        expected_prompt == formatter_response.prompt
+    ), "The formatted prompt does not match the expected output."
+
+
+# Optionally, add a test for the 'stop' field once that part of the functionality is implemented.
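For reviewers, a minimal usage sketch of the formatter API added in this diff follows. The import path and the `"llama-2"` registry name come from the new module above; the example messages, the `custom_template` string, the `CustomFormatter` class, and the `"custom"` registry name are illustrative assumptions and are not part of the change.

```python
# Usage sketch for the chat formatter API introduced in this diff.
# Assumes the module is importable as `llama_cpp.llama_jinja_format`,
# matching the file added above; messages and the custom template are
# illustrative only.
from llama_cpp.llama_jinja_format import AutoChatFormatter, ChatFormatterFactory

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How do I reverse a list in Python?"},
]

# Look up the registered llama-2 formatter through the singleton factory and
# render the messages into a prompt string.
formatter = ChatFormatterFactory().get_formatter_by_name("llama-2")
print(formatter(messages).prompt)

# Provide a custom Jinja2 template by subclassing AutoChatFormatter and
# registering it under a new (hypothetical) name.
custom_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)


class CustomFormatter(AutoChatFormatter):
    def __init__(self):
        super().__init__(custom_template)


ChatFormatterFactory().register_formatter("custom", CustomFormatter)
print(ChatFormatterFactory().get_formatter_by_name("custom")(messages).prompt)
```

Because `ChatFormatterFactory` is a singleton, the `"custom"` registration above is visible to every subsequent `ChatFormatterFactory()` call in the same process.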