Integration of Jinja2 Templating (#875)
* feat: Add support for jinja templating Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * fix: Refactor chat formatter and update interface for jinja templates - Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks for readability without affecting functionality. - Update `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility. - Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string. - Replace `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation. These changes enhance code readability, maintain usability, and ensure consistency in the chat formatter's design pattern usage. * Add outline for Jinja2 templating integration documentation Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * Add jinja2 as a dependency with version range for Hugging Face transformers compatibility Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * Update jinja2 version constraint for mkdocs-material compatibility Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com> * Fix attribute name in AutoChatFormatter - Changed attribute name from `self._renderer` to `self._environment` --------- Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
This commit is contained in:
parent
52adc23115
commit
6bfe98bd80
4 changed files with 243 additions and 1 deletions
52
docs/templates.md
Normal file
52
docs/templates.md
Normal file
|
@ -0,0 +1,52 @@
|
|||
# Templates
|
||||
|
||||
This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.
|
||||
|
||||
## Introduction
|
||||
|
||||
- Brief explanation of the `llama-cpp-python` project's need for a templating system.
|
||||
- Overview of the `llama-2` model's interaction with templating.
|
||||
|
||||
## Jinja2 Dependency Integration
|
||||
|
||||
- Rationale for choosing Jinja2 as the templating engine.
|
||||
- Compatibility with Hugging Face's `transformers`.
|
||||
- Desire for advanced templating features and simplicity.
|
||||
- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.
|
||||
|
||||
## Template Management Refactor
|
||||
|
||||
- Summary of the refactor and the motivation behind it.
|
||||
- Description of the new chat handler selection logic:
|
||||
1. Preference for a user-specified `chat_handler`.
|
||||
2. Fallback to a user-specified `chat_format`.
|
||||
3. Defaulting to a chat format from a `.gguf` file if available.
|
||||
4. Utilizing the `llama2` default chat format as the final fallback.
|
||||
- Ensuring backward compatibility throughout the refactor.
|
||||
|
||||
## Implementation Details
|
||||
|
||||
- In-depth look at the new `AutoChatFormatter` class.
|
||||
- Example code snippets showing how to utilize the Jinja2 environment and templates.
|
||||
- Guidance on how to provide custom templates or use defaults.
|
||||
|
||||
## Testing and Validation
|
||||
|
||||
- Outline of the testing strategy to ensure seamless integration.
|
||||
- Steps for validating backward compatibility with existing implementations.
|
||||
|
||||
## Benefits and Impact
|
||||
|
||||
- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
|
||||
- Discussion of the potential impact on current users and contributors.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Exploration of how templating can evolve within the project.
|
||||
- Consideration of additional features or optimizations for the templating engine.
|
||||
- Mechanisms for community feedback on the templating system.
|
||||
|
||||
## Conclusion
|
||||
|
||||
- Final thoughts on the integration of Jinja2 templating.
|
||||
- Call to action for community involvement and feedback.
|
138
llama_cpp/llama_jinja_format.py
Normal file
138
llama_cpp/llama_jinja_format.py
Normal file
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
llama_cpp/llama_jinja_format.py
|
||||
"""
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Union
|
||||
|
||||
import jinja2
|
||||
from jinja2 import Template
|
||||
|
||||
# NOTE: We sacrifice readability for usability.
|
||||
# It will fail to work as expected if we attempt to format it in a readable way.
|
||||
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
|
||||
|
||||
|
||||
class MetaSingleton(type):
|
||||
"""
|
||||
Metaclass for implementing the Singleton pattern.
|
||||
"""
|
||||
|
||||
_instances = {}
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class Singleton(object, metaclass=MetaSingleton):
|
||||
"""
|
||||
Base class for implementing the Singleton pattern.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(Singleton, self).__init__()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ChatFormatterResponse:
|
||||
prompt: str
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
|
||||
# Base Chat Formatter Protocol
|
||||
class ChatFormatterInterface(Protocol):
|
||||
def __init__(self, template: Optional[object] = None):
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
**kwargs,
|
||||
) -> ChatFormatterResponse:
|
||||
...
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
...
|
||||
|
||||
|
||||
class AutoChatFormatter(ChatFormatterInterface):
|
||||
def __init__(
|
||||
self,
|
||||
template: Optional[str] = None,
|
||||
template_class: Optional[Template] = None,
|
||||
):
|
||||
if template is not None:
|
||||
self._template = template
|
||||
else:
|
||||
self._template = llama2_template # default template
|
||||
|
||||
self._environment = jinja2.Environment(
|
||||
loader=jinja2.BaseLoader(),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
).from_string(
|
||||
self._template,
|
||||
template_class=template_class,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
formatted_sequence = self._environment.render(messages=messages, **kwargs)
|
||||
return ChatFormatterResponse(prompt=formatted_sequence)
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return self._template
|
||||
|
||||
|
||||
class FormatterNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ChatFormatterFactory(Singleton):
|
||||
_chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
|
||||
|
||||
def register_formatter(
|
||||
self,
|
||||
name: str,
|
||||
formatter_callable: Callable[[], ChatFormatterInterface],
|
||||
overwrite=False,
|
||||
):
|
||||
if not overwrite and name in self._chat_formatters:
|
||||
raise ValueError(
|
||||
f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
|
||||
)
|
||||
self._chat_formatters[name] = formatter_callable
|
||||
|
||||
def unregister_formatter(self, name: str):
|
||||
if name in self._chat_formatters:
|
||||
del self._chat_formatters[name]
|
||||
else:
|
||||
raise ValueError(f"No formatter registered under the name '{name}'.")
|
||||
|
||||
def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
|
||||
try:
|
||||
formatter_callable = self._chat_formatters[name]
|
||||
return formatter_callable()
|
||||
except KeyError:
|
||||
raise FormatterNotFoundException(
|
||||
f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
|
||||
)
|
||||
|
||||
|
||||
# Define a chat format class
|
||||
class Llama2Formatter(AutoChatFormatter):
|
||||
def __init__(self):
|
||||
super().__init__(llama2_template)
|
||||
|
||||
|
||||
# With the Singleton pattern applied, regardless of where or how many times
|
||||
# ChatFormatterFactory() is called, it will always return the same instance
|
||||
# of the factory, ensuring that the factory's state is consistent throughout
|
||||
# the application.
|
||||
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
|
|
@ -11,10 +11,13 @@ license = { text = "MIT" }
|
|||
authors = [
|
||||
{ name = "Andrei Betlen", email = "abetlen@gmail.com" },
|
||||
]
|
||||
# mkdocs-martiral requires "jinja2~=3.0"
|
||||
# transformers requires "jinja2>=2.11.3"
|
||||
dependencies = [
|
||||
"typing-extensions>=4.5.0",
|
||||
"numpy>=1.20.0",
|
||||
"diskcache>=5.6.1",
|
||||
"jinja2>=2.11.3",
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
classifiers = [
|
||||
|
@ -72,4 +75,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"
|
|||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--ignore=vendor"
|
||||
|
||||
|
|
50
tests/test_llama_chat_format.py
Normal file
50
tests/test_llama_chat_format.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_cpp import ChatCompletionMessage
|
||||
from llama_cpp.llama_jinja_format import Llama2Formatter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sequence_of_messages() -> List[ChatCompletionMessage]:
|
||||
return [
|
||||
ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Hi there! I need some help with Python."
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="assistant", content="Of course! What do you need help with in Python?"
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="user",
|
||||
content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content="I can help with that! Would you like a recursive or iterative solution?",
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Let's go with a recursive solution."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_llama2_formatter(sequence_of_messages):
|
||||
expected_prompt = (
|
||||
"<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
|
||||
"[INST] Hi there! I need some help with Python. [/INST]\n"
|
||||
"Of course! What do you need help with in Python?\n"
|
||||
"[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
|
||||
"I can help with that! Would you like a recursive or iterative solution?\n"
|
||||
"[INST] Let's go with a recursive solution. [/INST]\n"
|
||||
)
|
||||
|
||||
llama2_formatter_instance = Llama2Formatter()
|
||||
formatter_response = llama2_formatter_instance(sequence_of_messages)
|
||||
assert (
|
||||
expected_prompt == formatter_response.prompt
|
||||
), "The formatted prompt does not match the expected output."
|
||||
|
||||
|
||||
# Optionally, include a test for the 'stop' if it's part of the functionality.
|
Loading…
Reference in a new issue