Integration of Jinja2 Templating (#875)

* feat: Add support for jinja templating

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* fix: Refactor chat formatter and update interface for jinja templates

- Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks, without affecting functionality.
- Update `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility.
- Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string.
- Replace `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation.

These changes improve code readability, preserve usability, and keep the chat formatter's use of design patterns consistent.

* Add outline for Jinja2 templating integration documentation

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Add jinja2 as a dependency with version range for Hugging Face transformers compatibility

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Update jinja2 version constraint for mkdocs-material compatibility

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Fix attribute name in AutoChatFormatter

- Changed attribute name from `self._renderer` to `self._environment`

---------

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
Authored by Austin on 2024-01-17 09:47:52 -05:00; committed by GitHub
parent 52adc23115
commit 6bfe98bd80
4 changed files with 243 additions and 1 deletion

docs/templates.md (new file, 52 additions)

@@ -0,0 +1,52 @@
# Templates
This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.
## Introduction
- Brief explanation of the `llama-cpp-python` project's need for a templating system.
- Overview of the `llama-2` model's interaction with templating.
## Jinja2 Dependency Integration
- Rationale for choosing Jinja2 as the templating engine.
- Compatibility with Hugging Face's `transformers`.
- Desire for advanced templating features and simplicity.
- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.
## Template Management Refactor
- Summary of the refactor and the motivation behind it.
- Description of the new chat handler selection logic (sketched in the code after this list):
1. Preference for a user-specified `chat_handler`.
2. Fallback to a user-specified `chat_format`.
3. Defaulting to a chat format from a `.gguf` file if available.
4. Utilizing the `llama2` default chat format as the final fallback.
- Ensuring backward compatibility throughout the refactor.
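A minimal sketch of that fallback chain (the `resolve_chat_formatter` helper and its argument names are hypothetical; they illustrate the resolution order, not the library's actual signature):

```python
from typing import Optional

from llama_cpp.llama_jinja_format import (
    ChatFormatterFactory,
    ChatFormatterInterface,
)


def resolve_chat_formatter(
    chat_handler: Optional[ChatFormatterInterface] = None,
    chat_format: Optional[str] = None,
    gguf_chat_format: Optional[str] = None,
) -> ChatFormatterInterface:
    factory = ChatFormatterFactory()
    if chat_handler is not None:  # 1. a user-specified handler wins
        return chat_handler
    if chat_format is not None:  # 2. then a user-specified format name
        return factory.get_formatter_by_name(chat_format)
    if gguf_chat_format is not None:  # 3. then a format from the .gguf file
        return factory.get_formatter_by_name(gguf_chat_format)
    return factory.get_formatter_by_name("llama-2")  # 4. final fallback
```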
## Implementation Details
- In-depth look at the new `AutoChatFormatter` class.
- Example code snippets showing how to utilize the Jinja2 environment and templates.
- Guidance on how to provide custom templates or use defaults.
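As a usage sketch of the `AutoChatFormatter` class included in this commit (the custom template string below is illustrative):

```python
from llama_cpp.llama_jinja_format import AutoChatFormatter, llama2_template

# With no arguments, the formatter falls back to the built-in llama-2 template.
formatter = AutoChatFormatter()
assert formatter.template == llama2_template

# A caller-supplied Jinja2 template overrides the default.
custom_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)
custom_formatter = AutoChatFormatter(template=custom_template)
response = custom_formatter([{"role": "user", "content": "Hi"}])
print(response.prompt)  # prints "user: Hi"
```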
## Testing and Validation
- Outline of the testing strategy to ensure seamless integration.
- Steps for validating backward compatibility with existing implementations.
## Benefits and Impact
- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
- Discussion of the potential impact on current users and contributors.
## Future Work
- Exploration of how templating can evolve within the project.
- Consideration of additional features or optimizations for the templating engine.
- Mechanisms for community feedback on the templating system.
## Conclusion
- Final thoughts on the integration of Jinja2 templating.
- Call to action for community involvement and feedback.

llama_cpp/llama_jinja_format.py (new file, 138 additions)

@@ -0,0 +1,138 @@
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
import jinja2
from jinja2 import Template
# NOTE: We sacrifice readability for usability.
# It will fail to work as expected if we attempt to format it in a readable way.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
class MetaSingleton(type):
"""
Metaclass for implementing the Singleton pattern.
"""
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Singleton(object, metaclass=MetaSingleton):
"""
Base class for implementing the Singleton pattern.
"""
def __init__(self):
super(Singleton, self).__init__()
@dataclasses.dataclass
class ChatFormatterResponse:
prompt: str
stop: Optional[Union[str, List[str]]] = None
# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
def __init__(self, template: Optional[object] = None):
...
def __call__(
self,
messages: List[Dict[str, str]],
**kwargs,
) -> ChatFormatterResponse:
...
@property
def template(self) -> str:
...
class AutoChatFormatter(ChatFormatterInterface):
def __init__(
self,
template: Optional[str] = None,
template_class: Optional[Template] = None,
):
if template is not None:
self._template = template
else:
self._template = llama2_template # default template
self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(),
trim_blocks=True,
lstrip_blocks=True,
).from_string(
self._template,
template_class=template_class,
)
def __call__(
self,
messages: List[Dict[str, str]],
**kwargs: Any,
) -> ChatFormatterResponse:
formatted_sequence = self._environment.render(messages=messages, **kwargs)
return ChatFormatterResponse(prompt=formatted_sequence)
@property
def template(self) -> str:
return self._template
class FormatterNotFoundException(Exception):
pass
class ChatFormatterFactory(Singleton):
_chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
def register_formatter(
self,
name: str,
formatter_callable: Callable[[], ChatFormatterInterface],
overwrite=False,
):
if not overwrite and name in self._chat_formatters:
raise ValueError(
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
)
self._chat_formatters[name] = formatter_callable
def unregister_formatter(self, name: str):
if name in self._chat_formatters:
del self._chat_formatters[name]
else:
raise ValueError(f"No formatter registered under the name '{name}'.")
def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
try:
formatter_callable = self._chat_formatters[name]
return formatter_callable()
except KeyError:
raise FormatterNotFoundException(
f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
)
# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
def __init__(self):
super().__init__(llama2_template)
# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
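A short usage sketch of the module above (illustrative; not part of the committed file):

```python
from llama_cpp.llama_jinja_format import ChatFormatterFactory

# Look up the registered formatter through the singleton factory
# and render a single-message conversation.
formatter = ChatFormatterFactory().get_formatter_by_name("llama-2")
response = formatter([{"role": "user", "content": "Hello!"}])
print(response.prompt)  # prints "[INST] Hello! [/INST]"
```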

pyproject.toml

@@ -11,10 +11,13 @@ license = { text = "MIT" }
authors = [
{ name = "Andrei Betlen", email = "abetlen@gmail.com" },
]
# mkdocs-material requires "jinja2~=3.0"
# transformers requires "jinja2>=2.11.3"
dependencies = [
"typing-extensions>=4.5.0",
"numpy>=1.20.0",
"diskcache>=5.6.1",
"jinja2>=2.11.3",
]
requires-python = ">=3.8"
classifiers = [
@@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
[tool.pytest.ini_options]
addopts = "--ignore=vendor"

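Both consumers must share one installed copy of `jinja2`, so the constraint has to sit in the intersection of the two ranges; any 3.x release does. A quick check of the resolved version (illustrative sketch):

```python
from importlib.metadata import version

# mkdocs-material pins "jinja2~=3.0" and transformers needs ">=2.11.3",
# so any jinja2 3.x release satisfies both constraints at once.
print(version("jinja2"))
```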
New test file (50 additions)

@@ -0,0 +1,50 @@
from typing import List

import pytest

from llama_cpp import ChatCompletionMessage
from llama_cpp.llama_jinja_format import Llama2Formatter


@pytest.fixture
def sequence_of_messages() -> List[ChatCompletionMessage]:
    return [
        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
        ChatCompletionMessage(
            role="user", content="Hi there! I need some help with Python."
        ),
        ChatCompletionMessage(
            role="assistant", content="Of course! What do you need help with in Python?"
        ),
        ChatCompletionMessage(
            role="user",
            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
        ),
        ChatCompletionMessage(
            role="assistant",
            content="I can help with that! Would you like a recursive or iterative solution?",
        ),
        ChatCompletionMessage(
            role="user", content="Let's go with a recursive solution."
        ),
    ]


def test_llama2_formatter(sequence_of_messages):
    expected_prompt = (
        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
        "[INST] Hi there! I need some help with Python. [/INST]\n"
        "Of course! What do you need help with in Python?\n"
        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
        "I can help with that! Would you like a recursive or iterative solution?\n"
        "[INST] Let's go with a recursive solution. [/INST]\n"
    )
    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    assert (
        expected_prompt == formatter_response.prompt
    ), "The formatted prompt does not match the expected output."


# Optionally, include a test for the 'stop' if it's part of the functionality.
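# A minimal sketch of such a test: the llama-2 formatter never sets a stop
# sequence, so `ChatFormatterResponse.stop` should keep its default of None.
def test_llama2_formatter_stop_is_unset(sequence_of_messages):
    formatter_response = Llama2Formatter()(sequence_of_messages)
    assert formatter_response.stop is None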