Integration of Jinja2 Templating (#875)
* feat: Add support for jinja templating

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* fix: Refactor chat formatter and update interface for jinja templates

  - Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks for readability without affecting functionality.
  - Update `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility.
  - Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string.
  - Replace `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation.

  These changes enhance code readability, maintain usability, and ensure consistency in the chat formatter's design pattern usage.

* Add outline for Jinja2 templating integration documentation

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Add jinja2 as a dependency with version range for Hugging Face transformers compatibility

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Update jinja2 version constraint for mkdocs-material compatibility

  Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Fix attribute name in AutoChatFormatter

  - Changed attribute name from `self._renderer` to `self._environment`

---------

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
parent 52adc23115
commit 6bfe98bd80
4 changed files with 243 additions and 1 deletion
52
docs/templates.md
Normal file
@@ -0,0 +1,52 @@
# Templates

This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.

## Introduction

- Brief explanation of the `llama-cpp-python` project's need for a templating system.
- Overview of the `llama-2` model's interaction with templating.

## Jinja2 Dependency Integration

- Rationale for choosing Jinja2 as the templating engine.
  - Compatibility with Hugging Face's `transformers`.
  - Desire for advanced templating features and simplicity.
- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.

## Template Management Refactor

- Summary of the refactor and the motivation behind it.
- Description of the new chat handler selection logic:
  1. Preference for a user-specified `chat_handler`.
  2. Fallback to a user-specified `chat_format`.
  3. Defaulting to a chat format from a `.gguf` file if available.
  4. Utilizing the `llama2` default chat format as the final fallback.
- Ensuring backward compatibility throughout the refactor.
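
The four-step selection order listed above can be sketched as a simple chain of fallbacks. This is only an illustration of the ordering, not the project's implementation: the function name `resolve_chat_source`, the metadata key `tokenizer.chat_template`, and the use of `"llama-2"` as the default name are assumptions made for the example.

```python
from typing import Any, Callable, Mapping, Optional, Tuple

# Assumed metadata key; the outline does not say which field a .gguf file would use.
GGUF_TEMPLATE_KEY = "tokenizer.chat_template"


def resolve_chat_source(
    chat_handler: Optional[Callable[..., Any]] = None,
    chat_format: Optional[str] = None,
    gguf_metadata: Optional[Mapping[str, str]] = None,
    default_format: str = "llama-2",
) -> Tuple[str, Any]:
    """Return (source, value) for the first available option, in priority order."""
    # 1. A user-specified chat handler always wins.
    if chat_handler is not None:
        return "chat_handler", chat_handler
    # 2. Otherwise use a user-specified chat format name.
    if chat_format is not None:
        return "chat_format", chat_format
    # 3. Otherwise use a template stored in the model file, if there is one.
    if gguf_metadata and GGUF_TEMPLATE_KEY in gguf_metadata:
        return "gguf", gguf_metadata[GGUF_TEMPLATE_KEY]
    # 4. Final fallback: the default chat format.
    return "default", default_format


source, value = resolve_chat_source(chat_format="chatml")
print(source, value)
```

Returning the winning source alongside the value makes the decision easy to log and to test, which helps when checking backward compatibility.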

## Implementation Details

- In-depth look at the new `AutoChatFormatter` class.
- Example code snippets showing how to utilize the Jinja2 environment and templates.
- Guidance on how to provide custom templates or use defaults.
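
As a minimal sketch of rendering a custom template, the snippet below uses plain `jinja2` with the same environment options that `AutoChatFormatter` sets (`trim_blocks` and `lstrip_blocks`). The template string and the `<|role|>` markers are hypothetical and chosen only for illustration; they are not a format shipped by the project.

```python
import jinja2

# Hypothetical custom template. It is written without line breaks between tags,
# so the only newlines in the output are the explicit "\n" after each message.
custom_template = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>{{ message['content'] }}\n"
    "{% endfor %}"
)

environment = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    trim_blocks=True,    # drop the first newline after a block tag
    lstrip_blocks=True,  # strip leading whitespace before a block tag
)
template = environment.from_string(custom_template)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
prompt = template.render(messages=messages)
print(prompt)
```

Passing the same template string as the `template` argument of `AutoChatFormatter` should produce the same prompt, since the class builds its environment with these options; omitting the argument falls back to the built-in `llama2_template`.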

## Testing and Validation

- Outline of the testing strategy to ensure seamless integration.
- Steps for validating backward compatibility with existing implementations.

## Benefits and Impact

- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
- Discussion of the potential impact on current users and contributors.

## Future Work

- Exploration of how templating can evolve within the project.
- Consideration of additional features or optimizations for the templating engine.
- Mechanisms for community feedback on the templating system.

## Conclusion

- Final thoughts on the integration of Jinja2 templating.
- Call to action for community involvement and feedback.
138
llama_cpp/llama_jinja_format.py
Normal file
@@ -0,0 +1,138 @@
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Union

import jinja2
from jinja2 import Template

# NOTE: We sacrifice readability for usability.
# It will fail to work as expected if we attempt to format it in a readable way.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""


class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.
    """

    def __init__(self):
        super(Singleton, self).__init__()


@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
    def __init__(self, template: Optional[object] = None):
        ...

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> ChatFormatterResponse:
        ...

    @property
    def template(self) -> str:
        ...


class AutoChatFormatter(ChatFormatterInterface):
    def __init__(
        self,
        template: Optional[str] = None,
        template_class: Optional[Template] = None,
    ):
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        formatted_sequence = self._environment.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        return self._template


class FormatterNotFoundException(Exception):
    pass


class ChatFormatterFactory(Singleton):
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite=False,
    ):
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        if name in self._chat_formatters:
            del self._chat_formatters[name]
        else:
            raise ValueError(f"No formatter registered under the name '{name}'.")

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        try:
            formatter_callable = self._chat_formatters[name]
            return formatter_callable()
        except KeyError:
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            )


# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    def __init__(self):
        super().__init__(llama2_template)


# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
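
The singleton-plus-registry pattern used by `ChatFormatterFactory` can be illustrated with a trimmed-down, standard-library-only sketch. `FormatterRegistry` and `UpperCaseFormatter` are hypothetical names invented for this example; only the metaclass mirrors the one in the file above, and the registry here stores its entries on the instance rather than on the class.

```python
from typing import Callable, Dict, List


class MetaSingleton(type):
    """Metaclass that returns one shared instance per class."""

    _instances: Dict[type, object] = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class FormatterRegistry(metaclass=MetaSingleton):
    """Hypothetical, simplified stand-in for ChatFormatterFactory."""

    def __init__(self) -> None:
        # __init__ runs only once, because the metaclass caches the instance.
        self._formatters: Dict[str, Callable[[], Callable]] = {}

    def register(self, name: str, factory: Callable[[], Callable], overwrite: bool = False) -> None:
        if not overwrite and name in self._formatters:
            raise ValueError(f"Formatter '{name}' is already registered.")
        self._formatters[name] = factory

    def get(self, name: str) -> Callable:
        # Instantiate the formatter on lookup, as the factory above does.
        return self._formatters[name]()


class UpperCaseFormatter:
    """Hypothetical formatter used only for this demonstration."""

    def __call__(self, messages: List[Dict[str, str]]) -> str:
        return "\n".join(message["content"].upper() for message in messages)


# Register through one call site, look up through another: both see the same state.
FormatterRegistry().register("upper", UpperCaseFormatter)
same_registry = FormatterRegistry()
formatter = same_registry.get("upper")
result = formatter([{"role": "user", "content": "hi"}])
print(result)
```

Because every call to `FormatterRegistry()` yields the same object, a formatter registered at import time in one module is visible to lookups anywhere else in the process.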
pyproject.toml
@@ -11,10 +11,13 @@ license = { text = "MIT" }
 authors = [
     { name = "Andrei Betlen", email = "abetlen@gmail.com" },
 ]
+# mkdocs-material requires "jinja2~=3.0"
+# transformers requires "jinja2>=2.11.3"
 dependencies = [
     "typing-extensions>=4.5.0",
     "numpy>=1.20.0",
     "diskcache>=5.6.1",
+    "jinja2>=2.11.3",
 ]
 requires-python = ">=3.8"
 classifiers = [
@@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"

 [tool.pytest.ini_options]
 addopts = "--ignore=vendor"
-
50
tests/test_llama_chat_format.py
Normal file
@@ -0,0 +1,50 @@
from typing import List

import pytest

from llama_cpp import ChatCompletionMessage
from llama_cpp.llama_jinja_format import Llama2Formatter


@pytest.fixture
def sequence_of_messages() -> List[ChatCompletionMessage]:
    return [
        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
        ChatCompletionMessage(
            role="user", content="Hi there! I need some help with Python."
        ),
        ChatCompletionMessage(
            role="assistant", content="Of course! What do you need help with in Python?"
        ),
        ChatCompletionMessage(
            role="user",
            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
        ),
        ChatCompletionMessage(
            role="assistant",
            content="I can help with that! Would you like a recursive or iterative solution?",
        ),
        ChatCompletionMessage(
            role="user", content="Let's go with a recursive solution."
        ),
    ]


def test_llama2_formatter(sequence_of_messages):
    expected_prompt = (
        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
        "[INST] Hi there! I need some help with Python. [/INST]\n"
        "Of course! What do you need help with in Python?\n"
        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
        "I can help with that! Would you like a recursive or iterative solution?\n"
        "[INST] Let's go with a recursive solution. [/INST]\n"
    )

    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    assert (
        expected_prompt == formatter_response.prompt
    ), "The formatted prompt does not match the expected output."


# Optionally, include a test for the 'stop' if it's part of the functionality.
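
One possible shape for the optional `stop` test mentioned in the closing comment is sketched below. To stay runnable without the package installed, it redeclares a dataclass with the same fields as `ChatFormatterResponse`; the helper `stop_as_list` and both test names are hypothetical. In this commit `AutoChatFormatter` never sets `stop`, so the default-value check is the only behavior the source confirms.

```python
import dataclasses
from typing import List, Optional, Union


@dataclasses.dataclass
class ChatFormatterResponse:
    # Same fields as the dataclass in llama_cpp/llama_jinja_format.py.
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


def stop_as_list(response: ChatFormatterResponse) -> List[str]:
    """Hypothetical helper: normalize `stop` to a list so assertions stay uniform."""
    if response.stop is None:
        return []
    if isinstance(response.stop, str):
        return [response.stop]
    return list(response.stop)


def test_stop_defaults_to_none():
    response = ChatFormatterResponse(prompt="[INST] Hi [/INST]\n")
    assert response.stop is None
    assert stop_as_list(response) == []


def test_stop_accepts_string_or_list():
    single = ChatFormatterResponse(prompt="", stop="</s>")
    several = ChatFormatterResponse(prompt="", stop=["</s>", "[INST]"])
    assert stop_as_list(single) == ["</s>"]
    assert stop_as_list(several) == ["</s>", "[INST]"]
```

In the real test module, the redeclared dataclass would be replaced by an import from `llama_cpp.llama_jinja_format`, and the first test could assert on the response returned by `Llama2Formatter()`.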