llama.cpp/tests/test_llama_chat_format.py
Austin 6bfe98bd80
Integration of Jinja2 Templating (#875)
* feat: Add support for jinja templating

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* fix: Refactor chat formatter and update interface for jinja templates

- Simplify the `llama2_template` in `llama_jinja_format.py` by removing unnecessary line breaks, improving readability without affecting functionality.
- Update `ChatFormatterInterface` constructor to accept a more generic `Optional[object]` type for the template parameter, enhancing flexibility.
- Introduce a `template` property to `ChatFormatterInterface` for standardized access to the template string.
- Replace `MetaSingleton` metaclass with `Singleton` for the `ChatFormatterFactory` to streamline the singleton implementation.

These changes enhance code readability, maintain usability, and keep the chat formatter's design patterns consistent; a sketch of the resulting interface follows.
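
For orientation, here is a minimal sketch of what the refactored pieces might look like. The names `llama2_template`, `ChatFormatterInterface`, and the `template` property come from this commit; the bodies (including the template text and the `render` helper) are illustrative assumptions, not the actual implementation:

    from typing import List, Optional

    import jinja2

    # Illustrative single-line Jinja2 template in the spirit of `llama2_template`.
    llama2_template = (
        "{% for message in messages %}"
        "{% if message.role == 'system' %}<<SYS>> {{ message.content }} <</SYS>>\n"
        "{% elif message.role == 'user' %}[INST] {{ message.content }} [/INST]\n"
        "{% else %}{{ message.content }}\n"
        "{% endif %}"
        "{% endfor %}"
    )

    class ChatFormatterInterface:
        def __init__(self, template: Optional[object] = None):
            # Accepts any template-like object; falls back to the default template.
            self._template = str(template) if template is not None else llama2_template
            self._environment = jinja2.Environment(loader=jinja2.BaseLoader())

        @property
        def template(self) -> str:
            # Standardized access to the template string.
            return self._template

        def render(self, messages: List[dict]) -> str:
            return self._environment.from_string(self._template).render(messages=messages)

Rendering the test fixture below through such a template produces the `<<SYS>>`/`[INST]` prompt shape asserted in `test_llama2_formatter`.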

* Add outline for Jinja2 templating integration documentation

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Add jinja2 as a dependency with version range for Hugging Face transformers compatibility

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Update jinja2 version constraint for mkdocs-material compatibility

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>

* Fix attribute name in AutoChatFormatter

- Changed attribute name from `self._renderer` to `self._environment`
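
In context, the rename might look like the following hypothetical excerpt; only the rename from `self._renderer` to `self._environment` comes from this commit, the rest is illustrative:

    import jinja2

    class AutoChatFormatter:
        def __init__(self, template: str):
            self._template = template
            # Was `self._renderer`; `_environment` better reflects that this
            # object is the Jinja2 environment that compiles the template.
            self._environment = jinja2.Environment(loader=jinja2.BaseLoader())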

---------

Signed-off-by: teleprint-me <77757836+teleprint-me@users.noreply.github.com>
2024-01-17 09:47:52 -05:00


from typing import List

import pytest

from llama_cpp import ChatCompletionMessage
from llama_cpp.llama_jinja_format import Llama2Formatter


@pytest.fixture
def sequence_of_messages() -> List[ChatCompletionMessage]:
    return [
        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
        ChatCompletionMessage(
            role="user", content="Hi there! I need some help with Python."
        ),
        ChatCompletionMessage(
            role="assistant", content="Of course! What do you need help with in Python?"
        ),
        ChatCompletionMessage(
            role="user",
            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
        ),
        ChatCompletionMessage(
            role="assistant",
            content="I can help with that! Would you like a recursive or iterative solution?",
        ),
        ChatCompletionMessage(
            role="user", content="Let's go with a recursive solution."
        ),
    ]


def test_llama2_formatter(sequence_of_messages):
    expected_prompt = (
        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
        "[INST] Hi there! I need some help with Python. [/INST]\n"
        "Of course! What do you need help with in Python?\n"
        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
        "I can help with that! Would you like a recursive or iterative solution?\n"
        "[INST] Let's go with a recursive solution. [/INST]\n"
    )
    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    assert (
        expected_prompt == formatter_response.prompt
    ), "The formatted prompt does not match the expected output."
    # Optionally, include a test for the 'stop' sequences if they're part of the functionality.
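

# As the comment above suggests, a 'stop' test could look like this sketch.
# It assumes the formatter response exposes a `stop` attribute holding
# optional stop sequences; that attribute and its type are assumptions,
# not confirmed by this file.
def test_llama2_formatter_stop(sequence_of_messages):
    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    # Hypothetical contract: `stop` is None, a string, or a list of strings.
    assert formatter_response.stop is None or isinstance(
        formatter_response.stop, (str, list)
    ), "The 'stop' field should be None, a string, or a list of strings."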