From 6bfe98bd801262b68c7cb4761e67a626d91534c0 Mon Sep 17 00:00:00 2001
From: Austin <77757836+teleprint-me@users.noreply.github.com>
Date: Wed, 17 Jan 2024 09:47:52 -0500
Subject: [PATCH] Integration of Jinja2 Templating (#875)
* feat: Add support for jinja templating
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
* fix: Refactor chat formatter and update interface for jinja templates
- Simplify the `llama2_template` in `llama_jinja_format.py` by removing line breaks that had been added for readability but leaked into the rendered output.
- Update `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility.
- Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string.
- Replace `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation.
These changes improve code readability, preserve usability, and keep the chat formatter's design-pattern usage consistent.
* Add outline for Jinja2 templating integration documentation
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
* Add jinja2 as a dependency with version range for Hugging Face transformers compatibility
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
* Update jinja2 version constraint for mkdocs-material compatibility
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
* Fix attribute name in AutoChatFormatter
- Changed attribute name from `self._renderer` to `self._environment`
---------
Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
---
docs/templates.md | 52 ++++++++++++
llama_cpp/llama_jinja_format.py | 138 ++++++++++++++++++++++++++++++++
pyproject.toml | 4 +-
tests/test_llama_chat_format.py | 50 ++++++++++++
4 files changed, 243 insertions(+), 1 deletion(-)
create mode 100644 docs/templates.md
create mode 100644 llama_cpp/llama_jinja_format.py
create mode 100644 tests/test_llama_chat_format.py
diff --git a/docs/templates.md b/docs/templates.md
new file mode 100644
index 0000000..5acdaa1
--- /dev/null
+++ b/docs/templates.md
@@ -0,0 +1,52 @@
+# Templates
+
+This document outlines the planned documentation for the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.
+
+## Introduction
+
+- Brief explanation of the `llama-cpp-python` project's need for a templating system.
+- Overview of the `llama-2` model's interaction with templating.
+
+## Jinja2 Dependency Integration
+
+- Rationale for choosing Jinja2 as the templating engine.
+  - Compatibility with Hugging Face's `transformers`.
+  - Desire for advanced templating features and simplicity.
+- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.
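+
+For reference, the dependency entry as it lands in `pyproject.toml` (the full change appears in the diff below):
+
+```toml
+dependencies = [
+    "typing-extensions>=4.5.0",
+    "numpy>=1.20.0",
+    "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
+]
+```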
+
+## Template Management Refactor
+
+- Summary of the refactor and the motivation behind it.
+- Description of the new chat handler selection logic (a sketch in code follows this list):
+  1. Preference for a user-specified `chat_handler`.
+  2. Fallback to a user-specified `chat_format`.
+  3. Defaulting to a chat format from a `.gguf` file if available.
+  4. Utilizing the `llama2` default chat format as the final fallback.
+- Ensuring backward compatibility throughout the refactor.
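+
+A minimal sketch of that fallback chain (the function and metadata names here are illustrative, not the final API):
+
+```python
+from llama_cpp.llama_jinja_format import ChatFormatterFactory
+
+def resolve_chat_formatter(chat_handler=None, chat_format=None, gguf_metadata=None):
+    if chat_handler is not None:  # 1. user-specified handler wins
+        return chat_handler
+    if chat_format is not None:  # 2. user-specified format name
+        return ChatFormatterFactory().get_formatter_by_name(chat_format)
+    if gguf_metadata and "chat_format" in gguf_metadata:  # 3. format from the .gguf file
+        return ChatFormatterFactory().get_formatter_by_name(gguf_metadata["chat_format"])
+    return ChatFormatterFactory().get_formatter_by_name("llama-2")  # 4. final fallback
+```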
+
+## Implementation Details
+
+- In-depth look at the new `AutoChatFormatter` class.
+- Example code snippets showing how to utilize the Jinja2 environment and templates (one such snippet follows this list).
+- Guidance on how to provide custom templates or use defaults.
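+
+A short usage sketch of the `AutoChatFormatter` introduced in this patch:
+
+```python
+from llama_cpp.llama_jinja_format import AutoChatFormatter, llama2_template
+
+# Uses the built-in llama-2 template by default; pass any Jinja2 template
+# string to customize the prompt format.
+formatter = AutoChatFormatter(template=llama2_template)
+response = formatter([{"role": "user", "content": "Hello!"}])
+print(response.prompt)  # [INST] Hello! [/INST]
+```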
+
+## Testing and Validation
+
+- Outline of the testing strategy to ensure seamless integration.
+- Steps for validating backward compatibility with existing implementations.
+
+## Benefits and Impact
+
+- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
+- Discussion of the potential impact on current users and contributors.
+
+## Future Work
+
+- Exploration of how templating can evolve within the project.
+- Consideration of additional features or optimizations for the templating engine.
+- Mechanisms for community feedback on the templating system.
+
+## Conclusion
+
+- Final thoughts on the integration of Jinja2 templating.
+- Call to action for community involvement and feedback.
diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py
new file mode 100644
index 0000000..68faaf6
--- /dev/null
+++ b/llama_cpp/llama_jinja_format.py
@@ -0,0 +1,138 @@
+"""
+llama_cpp/llama_jinja_format.py
+"""
+import dataclasses
+from typing import Any, Callable, Dict, List, Optional, Protocol, Union
+
+import jinja2
+from jinja2 import Template
+
+# NOTE: The template is deliberately kept on a single line: line breaks added
+# for readability would be rendered into the prompt and change the output.
+llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
+
+
+class MetaSingleton(type):
+    """
+    Metaclass for implementing the Singleton pattern.
+    """
+
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class Singleton(metaclass=MetaSingleton):
+    """
+    Base class for implementing the Singleton pattern.
+    """
+
+    def __init__(self):
+        super(Singleton, self).__init__()
+
+
+@dataclasses.dataclass
+class ChatFormatterResponse:
+    prompt: str
+    stop: Optional[Union[str, List[str]]] = None
+
+
+# Base Chat Formatter Protocol
+class ChatFormatterInterface(Protocol):
+    def __init__(self, template: Optional[object] = None):
+        ...
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        **kwargs,
+    ) -> ChatFormatterResponse:
+        ...
+
+    @property
+    def template(self) -> str:
+        ...
+
+
+class AutoChatFormatter(ChatFormatterInterface):
+    def __init__(
+        self,
+        template: Optional[str] = None,
+        template_class: Optional[Template] = None,
+    ):
+        if template is not None:
+            self._template = template
+        else:
+            self._template = llama2_template  # default template
+
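+        # `Environment.from_string` compiles the template string and returns
+        # a `jinja2.Template` bound to this environment.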
+        self._environment = jinja2.Environment(
+            loader=jinja2.BaseLoader(),
+            trim_blocks=True,
+            lstrip_blocks=True,
+        ).from_string(
+            self._template,
+            template_class=template_class,
+        )
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        formatted_sequence = self._environment.render(messages=messages, **kwargs)
+        return ChatFormatterResponse(prompt=formatted_sequence)
+
+    @property
+    def template(self) -> str:
+        return self._template
+
+
+class FormatterNotFoundException(Exception):
+    pass
+
+
+class ChatFormatterFactory(Singleton):
+    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
+
+    def register_formatter(
+        self,
+        name: str,
+        formatter_callable: Callable[[], ChatFormatterInterface],
+        overwrite: bool = False,
+    ):
+        if not overwrite and name in self._chat_formatters:
+            raise ValueError(
+                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
+            )
+        self._chat_formatters[name] = formatter_callable
+
+    def unregister_formatter(self, name: str):
+        if name in self._chat_formatters:
+            del self._chat_formatters[name]
+        else:
+            raise ValueError(f"No formatter registered under the name '{name}'.")
+
+    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
+        try:
+            formatter_callable = self._chat_formatters[name]
+            return formatter_callable()
+        except KeyError:
+            raise FormatterNotFoundException(
+                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
+            )
+
+
+# Define a chat format class
+class Llama2Formatter(AutoChatFormatter):
+    def __init__(self):
+        super().__init__(llama2_template)
+
+
+# With the Singleton pattern applied, regardless of where or how many times
+# ChatFormatterFactory() is called, it will always return the same instance
+# of the factory, ensuring that the factory's state is consistent throughout
+# the application.
+ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
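+
+# Example: look up and invoke a registered formatter.
+#   formatter = ChatFormatterFactory().get_formatter_by_name("llama-2")
+#   response = formatter([{"role": "user", "content": "Hello!"}])
+#   response.prompt  # -> "[INST] Hello! [/INST]\n"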
diff --git a/pyproject.toml b/pyproject.toml
index b5affaa..806127d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,10 +11,13 @@ license = { text = "MIT" }
 authors = [
     { name = "Andrei Betlen", email = "abetlen@gmail.com" },
 ]
+# mkdocs-material requires "jinja2~=3.0"
+# transformers requires "jinja2>=2.11.3"
 dependencies = [
     "typing-extensions>=4.5.0",
     "numpy>=1.20.0",
     "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
 ]
 requires-python = ">=3.8"
 classifiers = [
@@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
 [tool.pytest.ini_options]
 addopts = "--ignore=vendor"
-
diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py
new file mode 100644
index 0000000..4eebcb6
--- /dev/null
+++ b/tests/test_llama_chat_format.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import pytest
+
+from llama_cpp import ChatCompletionMessage
+from llama_cpp.llama_jinja_format import Llama2Formatter
+
+
+@pytest.fixture
+def sequence_of_messages() -> List[ChatCompletionMessage]:
+    return [
+        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
+        ChatCompletionMessage(
+            role="user", content="Hi there! I need some help with Python."
+        ),
+        ChatCompletionMessage(
+            role="assistant", content="Of course! What do you need help with in Python?"
+        ),
+        ChatCompletionMessage(
+            role="user",
+            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
+        ),
+        ChatCompletionMessage(
+            role="assistant",
+            content="I can help with that! Would you like a recursive or iterative solution?",
+        ),
+        ChatCompletionMessage(
+            role="user", content="Let's go with a recursive solution."
+        ),
+    ]
+
+
+def test_llama2_formatter(sequence_of_messages):
+    expected_prompt = (
+        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
+        "[INST] Hi there! I need some help with Python. [/INST]\n"
+        "Of course! What do you need help with in Python?\n"
+        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
+        "I can help with that! Would you like a recursive or iterative solution?\n"
+        "[INST] Let's go with a recursive solution. [/INST]\n"
+    )
+
+    llama2_formatter_instance = Llama2Formatter()
+    formatter_response = llama2_formatter_instance(sequence_of_messages)
+    assert (
+        expected_prompt == formatter_response.prompt
+    ), "The formatted prompt does not match the expected output."
+
+
+# Optionally, the 'stop' field can be tested as well:
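+
+
+# `AutoChatFormatter` does not populate `stop`, so it should default to None.
+def test_llama2_formatter_stop_default(sequence_of_messages):
+    formatter_response = Llama2Formatter()(sequence_of_messages)
+    assert formatter_response.stop is None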