diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py
index f7b6ba6..4a10647 100644
--- a/llama_cpp/_utils.py
+++ b/llama_cpp/_utils.py
@@ -1,7 +1,8 @@
import os
import sys
-import sys, traceback
+
+from typing import Any, Dict
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
outnull_file = open(os.devnull, "w")
@@ -55,3 +56,25 @@ class suppress_stdout_stderr(object):
self.os.close(self.old_stdout_fileno)
self.os.close(self.old_stderr_fileno)
+
+
+class MetaSingleton(type):
+ """
+ Metaclass for implementing the Singleton pattern.
+ """
+
+ _instances: Dict[type, Any] = {}
+
+ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+ if cls not in cls._instances:
+ cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
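+
+    # Illustrative usage (a sketch, not part of the library API): any class
+    # built with this metaclass always yields the same instance:
+    #   class Config(metaclass=MetaSingleton):
+    #       pass
+    #   assert Config() is Config()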
+
+
+class Singleton(object, metaclass=MetaSingleton):
+ """
+ Base class for implementing the Singleton pattern.
+ """
+
+ def __init__(self):
+ super(Singleton, self).__init__()
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 0ef7bd4..3d18d90 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -6,18 +6,28 @@ import ctypes
import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
+import jinja2
+
import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar
-from ._utils import suppress_stdout_stderr
+from ._utils import suppress_stdout_stderr, Singleton
class LlamaChatCompletionHandler(Protocol):
+ """Base Protocol for a llama chat completion handler.
+
+    A very generic protocol that can be used to implement any chat format.
+ The only hard requirement is that it must return a ChatCompletion when
+ stream=False and an iterator of ChatCompletionChunks when stream=True."""
+
def __call__(
self,
*,
+ # llama.cpp instance
llama: llama.Llama,
+ # openai api parameters
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
@@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol):
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
- min_p: float = 0.05,
- typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
@@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol):
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
+ model: Optional[str] = None,
+ logit_bias: Optional[Dict[str, float]] = None,
+ # llama.cpp parameters
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
- model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
- logit_bias: Optional[Dict[str, float]] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.CreateChatCompletionResponse,
@@ -54,21 +65,83 @@ class LlamaChatCompletionHandler(Protocol):
...
-CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
+class LlamaChatCompletionHandlerNotFoundException(Exception):
+ pass
+
+
+class LlamaChatCompletionHandlerRegistry(Singleton):
+ _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
+
+ def register_chat_completion_handler(
+ self,
+ name: str,
+ chat_handler: LlamaChatCompletionHandler,
+ overwrite: bool = False,
+ ):
+ if not overwrite and name in self._chat_handlers:
+ raise ValueError(
+ f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
+ )
+ self._chat_handlers[name] = chat_handler
+
+ def unregister_chat_handler(self, name: str):
+ if name in self._chat_handlers:
+ del self._chat_handlers[name]
+ else:
+ raise ValueError(f"No formatter registered under the name '{name}'.")
+
+ def get_chat_completion_handler_by_name(
+ self, name: str
+ ) -> LlamaChatCompletionHandler:
+ try:
+ chat_handler = self._chat_handlers[name]
+ return chat_handler
+ except KeyError:
+ raise LlamaChatCompletionHandlerNotFoundException(
+ f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
+ )
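+
+
+# Because the registry derives from Singleton, every call to
+# LlamaChatCompletionHandlerRegistry() returns the same instance, so a handler
+# registered in one module is visible everywhere (a sketch):
+#   registry = LlamaChatCompletionHandlerRegistry()
+#   assert registry is LlamaChatCompletionHandlerRegistry()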
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
- return CHAT_HANDLERS[name]
+ return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
+ name
+ )
def register_chat_completion_handler(name: str):
def decorator(f: LlamaChatCompletionHandler):
- CHAT_HANDLERS[name] = f
+ LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
return f
return decorator
+### Chat Formatter ###
+
+
+@dataclasses.dataclass
+class ChatFormatterResponse:
+ prompt: str
+ stop: Optional[Union[str, List[str]]] = None
+
+
+class ChatFormatter(Protocol):
+ """Base Protocol for a chat formatter. A chat formatter is a function that
+ takes a list of messages and returns a formatted prompt. It can also return
+ a stop token or list of stop tokens to use for the completion."""
+
+ def __call__(
+ self,
+ *,
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ **kwargs: Any,
+ ) -> ChatFormatterResponse:
+ ...
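+
+
+# A minimal formatter satisfying this protocol (an illustrative sketch;
+# `format_plain` is a hypothetical name, not part of the library):
+#   def format_plain(*, messages, **kwargs) -> ChatFormatterResponse:
+#       text = "\n".join(f"{m['role']}: {m.get('content') or ''}" for m in messages)
+#       return ChatFormatterResponse(prompt=text + "\nassistant: ")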
+
+
+### Utility functions for formatting chat prompts ###
+
+
def _get_system_message(
messages: List[llama_types.ChatCompletionRequestMessage],
) -> str:
@@ -80,14 +153,18 @@ def _get_system_message(
def _map_roles(
- messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str]
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ role_map: Dict[str, str],
) -> List[Tuple[str, Optional[str]]]:
"""Map the message roles."""
output: List[Tuple[str, Optional[str]]] = []
for message in messages:
role = message["role"]
if role in role_map:
- output.append((role_map[role], message["content"]))
+        content: Optional[str] = (
+ message["content"] if isinstance(message["content"], str) else None
+ )
+ output.append((role_map[role], content))
return output
@@ -99,7 +176,8 @@ def _format_llama2(
ret = system_message + sep
for i, (role, message) in enumerate(messages):
if system_message and i == 0:
- ret += message + seps[i % 2]
+ m = message or ""
+ ret += m + seps[i % 2]
elif message:
ret += role + message + " " + seps[i % 2]
else:
@@ -172,6 +250,7 @@ def _format_chatml(
ret += role + "\n"
return ret
+
def _format_chatglm3(
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
@@ -187,30 +266,10 @@ def _format_chatglm3(
return ret
-@dataclasses.dataclass
-class ChatFormatterResponse:
- prompt: str
- stop: Optional[Union[str, List[str]]] = None
-
-
-class ChatFormatter(Protocol):
- def __call__(
- self,
- *,
- messages: List[llama_types.ChatCompletionRequestMessage],
- **kwargs: Any,
- ) -> ChatFormatterResponse:
- ...
-
-
-class BasicChatHandler:
- def __init__(self, chat_format: str):
- self.chat_format = chat_format
-
-
def _convert_text_completion_to_chat(
completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
+ assert "usage" in completion
return {
"id": "chat" + completion["id"],
"object": "chat.completion",
@@ -286,103 +345,95 @@ def _convert_completion_to_chat(
return _convert_text_completion_to_chat(completion)
-_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
+def chat_formatter_to_chat_completion_handler(
+ chat_formatter: ChatFormatter,
+) -> LlamaChatCompletionHandler:
+ def chat_completion_handler(
+ *,
+ llama: llama.Llama,
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+ function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+ tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+ tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+ temperature: float = 0.2,
+ top_p: float = 0.95,
+ top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
+ stream: bool = False,
+ stop: Optional[Union[str, List[str]]] = [],
+ seed: Optional[int] = None,
+ response_format: Optional[
+ llama_types.ChatCompletionRequestResponseFormat
+ ] = None,
+ max_tokens: Optional[int] = None,
+ presence_penalty: float = 0.0,
+ frequency_penalty: float = 0.0,
+ repeat_penalty: float = 1.1,
+ tfs_z: float = 1.0,
+ mirostat_mode: int = 0,
+ mirostat_tau: float = 5.0,
+ mirostat_eta: float = 0.1,
+ model: Optional[str] = None,
+ logits_processor: Optional[llama.LogitsProcessorList] = None,
+ grammar: Optional[llama.LlamaGrammar] = None,
+ logit_bias: Optional[Dict[str, float]] = None,
+ **kwargs, # type: ignore
+ ) -> Union[
+ llama_types.CreateChatCompletionResponse,
+ Iterator[llama_types.CreateChatCompletionStreamResponse],
+ ]:
+ result = chat_formatter(
+ messages=messages,
+ functions=functions,
+ function_call=function_call,
+ )
+ prompt = result.prompt
+ if result.stop is not None:
+ stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+ rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+ stop = stop + rstop
+
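+        # If the caller requested a JSON response, constrain sampling with a
+        # JSON grammar so that only valid JSON can be generated.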
+ if response_format is not None and response_format["type"] == "json_object":
+ grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+
+ completion_or_chunks = llama.create_completion(
+ prompt=prompt,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ min_p=min_p,
+ typical_p=typical_p,
+ stream=stream,
+ stop=stop,
+ seed=seed,
+ max_tokens=max_tokens,
+ presence_penalty=presence_penalty,
+ frequency_penalty=frequency_penalty,
+ repeat_penalty=repeat_penalty,
+ tfs_z=tfs_z,
+ mirostat_mode=mirostat_mode,
+ mirostat_tau=mirostat_tau,
+ mirostat_eta=mirostat_eta,
+ model=model,
+ logits_processor=logits_processor,
+ grammar=grammar,
+ logit_bias=logit_bias,
+ )
+ return _convert_completion_to_chat(completion_or_chunks, stream=stream)
+
+ return chat_completion_handler
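+
+
+# Usage sketch (illustrative names): wrap any ChatFormatter into a completion
+# handler and register it under a format name:
+#   handler = chat_formatter_to_chat_completion_handler(my_formatter)
+#   LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+#       "my-format", handler
+#   )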
def register_chat_format(name: str):
def decorator(f: ChatFormatter):
- def basic_create_chat_completion(
- *,
- llama: llama.Llama,
- messages: List[llama_types.ChatCompletionRequestMessage],
- functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
- function_call: Optional[
- llama_types.ChatCompletionRequestFunctionCall
- ] = None,
- tools: Optional[List[llama_types.ChatCompletionTool]] = None,
- tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
- temperature: float = 0.2,
- top_p: float = 0.95,
- top_k: int = 40,
- min_p: float = 0.05,
- typical_p: float = 1.0,
- stream: bool = False,
- stop: Optional[Union[str, List[str]]] = [],
- seed: Optional[int] = None,
- response_format: Optional[
- llama_types.ChatCompletionRequestResponseFormat
- ] = None,
- max_tokens: Optional[int] = None,
- presence_penalty: float = 0.0,
- frequency_penalty: float = 0.0,
- repeat_penalty: float = 1.1,
- tfs_z: float = 1.0,
- mirostat_mode: int = 0,
- mirostat_tau: float = 5.0,
- mirostat_eta: float = 0.1,
- model: Optional[str] = None,
- logits_processor: Optional[llama.LogitsProcessorList] = None,
- grammar: Optional[llama.LlamaGrammar] = None,
- logit_bias: Optional[Dict[str, float]] = None,
- **kwargs, # type: ignore
- ) -> Union[
- llama_types.CreateChatCompletionResponse,
- Iterator[llama_types.CreateChatCompletionStreamResponse],
- ]:
- result = f(
- messages=messages,
- functions=functions,
- function_call=function_call,
- )
- prompt = result.prompt
- if result.stop is not None:
- stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
- rstop = result.stop if isinstance(result.stop, list) else [result.stop]
- stop = stop + rstop
-
- if response_format is not None and response_format["type"] == "json_object":
- grammar = llama_grammar.LlamaGrammar.from_string(
- llama_grammar.JSON_GBNF
- )
-
- completion_or_chunks = llama.create_completion(
- prompt=prompt,
- temperature=temperature,
- top_p=top_p,
- top_k=top_k,
- min_p=min_p,
- typical_p=typical_p,
- stream=stream,
- stop=stop,
- seed=seed,
- max_tokens=max_tokens,
- presence_penalty=presence_penalty,
- frequency_penalty=frequency_penalty,
- repeat_penalty=repeat_penalty,
- tfs_z=tfs_z,
- mirostat_mode=mirostat_mode,
- mirostat_tau=mirostat_tau,
- mirostat_eta=mirostat_eta,
- model=model,
- logits_processor=logits_processor,
- grammar=grammar,
- logit_bias=logit_bias,
- )
- return _convert_completion_to_chat(completion_or_chunks, stream=stream)
-
- register_chat_completion_handler(name)(basic_create_chat_completion)
- return f
-
- return decorator
-
-
-def get_chat_format(name: str):
- try:
- return _CHAT_FORMATS[name]
- except KeyError:
- raise ValueError(
- f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
+ chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+ LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+ name, chat_completion_handler
)
+ return f
+ return decorator
def hf_autotokenizer_to_chat_formatter(
@@ -391,22 +442,78 @@ def hf_autotokenizer_to_chat_formatter(
# https://huggingface.co/docs/transformers/main/chat_templating
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
- from transformers import AutoTokenizer
+ from transformers import AutoTokenizer # type: ignore
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) # type: ignore
def format_autotokenizer(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
- tokenizer.use_default_system_prompt = False
- _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+ tokenizer.use_default_system_prompt = False # type: ignore
+ prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
+ assert isinstance(prompt, str)
# Return formatted prompt and eos token by default
- return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
+ return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
return format_autotokenizer
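+
+
+# Usage sketch (requires the optional `transformers` dependency; the model
+# name is only an example):
+#   formatter = hf_autotokenizer_to_chat_formatter(
+#       "mistralai/Mistral-7B-Instruct-v0.1"
+#   )
+#   result = formatter(messages=[{"role": "user", "content": "Hi"}])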
+def hf_autotokenizer_to_chat_completion_handler(
+ pretrained_model_name_or_path: Union[str, os.PathLike[str]]
+) -> LlamaChatCompletionHandler:
+ chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
+ return chat_formatter_to_chat_completion_handler(chat_formatter)
+
+
+def hf_tokenizer_config_to_chat_formatter(
+    tokenizer_config: Dict[str, Any],
+) -> ChatFormatter:
+ assert isinstance(tokenizer_config, dict)
+
+ assert "chat_template" in tokenizer_config
+ assert isinstance(tokenizer_config["chat_template"], str)
+ chat_template = tokenizer_config["chat_template"]
+
+ assert "bos_token" in tokenizer_config
+ assert isinstance(tokenizer_config["bos_token"], str)
+ bos_token = tokenizer_config["bos_token"]
+
+ assert "eos_token" in tokenizer_config
+ assert isinstance(tokenizer_config["eos_token"], str)
+ eos_token = tokenizer_config["eos_token"]
+
+ env = jinja2.Environment(
+ loader=jinja2.BaseLoader(),
+ trim_blocks=True,
+ lstrip_blocks=True,
+ ).from_string(chat_template)
+
+    def format_tokenizer_config(
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ **kwargs: Any,
+ ) -> ChatFormatterResponse:
+        # TODO: verify this is correct
+ # Add a blank assistant message to the end of the messages to prompt the model to generate a response
+ prompt = env.render(
+ messages=[
+ *messages,
+ llama_types.ChatCompletionRequestAssistantMessage(
+ role="assistant", content=""
+ ),
+ ],
+ bos_token=bos_token,
+ eos_token=eos_token,
+ )
+ return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+    return format_tokenizer_config
+
+
+def hf_tokenizer_config_to_chat_completion_handler(
+ tokenizer_config: Dict[str, Any],
+) -> LlamaChatCompletionHandler:
+ chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
+ return chat_formatter_to_chat_completion_handler(chat_formatter)
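+
+
+# Usage sketch (the path is illustrative): build a chat completion handler
+# directly from a HuggingFace tokenizer_config.json:
+#   import json
+#   with open("tokenizer_config.json") as f:
+#       handler = hf_tokenizer_config_to_chat_completion_handler(json.load(f))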
+
+
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
@register_chat_format("llama-2")
@@ -437,21 +544,23 @@ def format_alpaca(
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
return ChatFormatterResponse(prompt=_prompt)
+
@register_chat_format("qwen")
def format_qwen(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
_roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
- system_message="You are a helpful assistant."
- system_template="<|im_start|>system\n{system_message}"
- system_message=system_template.format(system_message=system_message)
+ system_message = "You are a helpful assistant."
+ system_template = "<|im_start|>system\n{system_message}"
+ system_message = system_template.format(system_message=system_message)
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_sep = "<|im_end|>"
_prompt = _format_chatml(system_message, _messages, _sep)
_sep2 = "<|endoftext|>"
- return ChatFormatterResponse(prompt=_prompt,stop=_sep2)
+ return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
+
@register_chat_format("vicuna")
def format(
@@ -650,6 +759,7 @@ def format_mistrallite(
_prompt = _format_no_colon_single(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt)
+
@register_chat_format("zephyr")
def format_zephyr(
messages: List[llama_types.ChatCompletionRequestMessage],
@@ -699,6 +809,7 @@ def format_chatml(
_prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
@register_chat_format("chatglm3")
def format_chatglm3(
messages: List[llama_types.ChatCompletionRequestMessage],
@@ -739,7 +850,7 @@ def format_openchat(
@register_chat_format("saiga")
def format_saiga(
messages: list[llama_types.ChatCompletionRequestMessage],
- **kwargs,
+ **kwargs: Any,
) -> ChatFormatterResponse:
_message_template = "{role}\n{content}"
_roles = dict(user="user", bot="bot", system="system")
diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py
deleted file mode 100644
index 68faaf6..0000000
--- a/llama_cpp/llama_jinja_format.py
+++ /dev/null
@@ -1,138 +0,0 @@
-"""
-llama_cpp/llama_jinja_format.py
-"""
-import dataclasses
-from typing import Any, Callable, Dict, List, Optional, Protocol, Union
-
-import jinja2
-from jinja2 import Template
-
-# NOTE: We sacrifice readability for usability.
-# It will fail to work as expected if we attempt to format it in a readable way.
-llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
-
-
-class MetaSingleton(type):
- """
- Metaclass for implementing the Singleton pattern.
- """
-
- _instances = {}
-
- def __call__(cls, *args, **kwargs):
- if cls not in cls._instances:
- cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
- return cls._instances[cls]
-
-
-class Singleton(object, metaclass=MetaSingleton):
- """
- Base class for implementing the Singleton pattern.
- """
-
- def __init__(self):
- super(Singleton, self).__init__()
-
-
-@dataclasses.dataclass
-class ChatFormatterResponse:
- prompt: str
- stop: Optional[Union[str, List[str]]] = None
-
-
-# Base Chat Formatter Protocol
-class ChatFormatterInterface(Protocol):
- def __init__(self, template: Optional[object] = None):
- ...
-
- def __call__(
- self,
- messages: List[Dict[str, str]],
- **kwargs,
- ) -> ChatFormatterResponse:
- ...
-
- @property
- def template(self) -> str:
- ...
-
-
-class AutoChatFormatter(ChatFormatterInterface):
- def __init__(
- self,
- template: Optional[str] = None,
- template_class: Optional[Template] = None,
- ):
- if template is not None:
- self._template = template
- else:
- self._template = llama2_template # default template
-
- self._environment = jinja2.Environment(
- loader=jinja2.BaseLoader(),
- trim_blocks=True,
- lstrip_blocks=True,
- ).from_string(
- self._template,
- template_class=template_class,
- )
-
- def __call__(
- self,
- messages: List[Dict[str, str]],
- **kwargs: Any,
- ) -> ChatFormatterResponse:
- formatted_sequence = self._environment.render(messages=messages, **kwargs)
- return ChatFormatterResponse(prompt=formatted_sequence)
-
- @property
- def template(self) -> str:
- return self._template
-
-
-class FormatterNotFoundException(Exception):
- pass
-
-
-class ChatFormatterFactory(Singleton):
- _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
-
- def register_formatter(
- self,
- name: str,
- formatter_callable: Callable[[], ChatFormatterInterface],
- overwrite=False,
- ):
- if not overwrite and name in self._chat_formatters:
- raise ValueError(
- f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
- )
- self._chat_formatters[name] = formatter_callable
-
- def unregister_formatter(self, name: str):
- if name in self._chat_formatters:
- del self._chat_formatters[name]
- else:
- raise ValueError(f"No formatter registered under the name '{name}'.")
-
- def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
- try:
- formatter_callable = self._chat_formatters[name]
- return formatter_callable()
- except KeyError:
- raise FormatterNotFoundException(
- f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
- )
-
-
-# Define a chat format class
-class Llama2Formatter(AutoChatFormatter):
- def __init__(self):
- super().__init__(llama2_template)
-
-
-# With the Singleton pattern applied, regardless of where or how many times
-# ChatFormatterFactory() is called, it will always return the same instance
-# of the factory, ensuring that the factory's state is consistent throughout
-# the application.
-ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py
index f9be323..c2d6b6d 100644
--- a/llama_cpp/server/model.py
+++ b/llama_cpp/server/model.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import json
+
from typing import Dict, Optional, Union, List
import llama_cpp
@@ -71,7 +73,25 @@ class LlamaProxy:
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
-
+ elif settings.chat_format == "hf-autotokenizer":
+ assert (
+ settings.hf_pretrained_model_name_or_path is not None
+ ), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
+ chat_handler = (
+            llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_completion_handler(
+ settings.hf_pretrained_model_name_or_path
+ )
+ )
+ elif settings.chat_format == "hf-tokenizer-config":
+ assert (
+ settings.hf_tokenizer_config_path is not None
+ ), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
+ chat_handler = (
+            llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
+ json.load(open(settings.hf_tokenizer_config_path))
+ )
+ )
+
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list)
@@ -141,4 +161,3 @@ class LlamaProxy:
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
_model.set_cache(cache)
return _model
-
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index dc5be20..9f0dc8a 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -134,6 +134,15 @@ class ModelSettings(BaseSettings):
default=2 << 30,
description="The size of the cache in bytes. Only used if cache is True.",
)
+ # Tokenizer Options
+ hf_tokenizer_config_path: Optional[str] = Field(
+ default=None,
+ description="The path to a HuggingFace tokenizer_config.json file.",
+ )
+ hf_pretrained_model_name_or_path: Optional[str] = Field(
+ default=None,
+ description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
+ )
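+    # Example invocation (a sketch; the server exposes each settings field as
+    # a CLI flag of the same name):
+    #   python3 -m llama_cpp.server --model model.gguf \
+    #     --chat_format hf-tokenizer-config \
+    #     --hf_tokenizer_config_path tokenizer_config.json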
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."
diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py
index 4eebcb6..1ef18d9 100644
--- a/tests/test_llama_chat_format.py
+++ b/tests/test_llama_chat_format.py
@@ -1,50 +1,65 @@
-from typing import List
+import json
-import pytest
-
-from llama_cpp import ChatCompletionMessage
-from llama_cpp.llama_jinja_format import Llama2Formatter
+from llama_cpp import (
+ ChatCompletionRequestUserMessage,
+)
+from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
-@pytest.fixture
-def sequence_of_messages() -> List[ChatCompletionMessage]:
- return [
- ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
- ChatCompletionMessage(
- role="user", content="Hi there! I need some help with Python."
- ),
- ChatCompletionMessage(
- role="assistant", content="Of course! What do you need help with in Python?"
- ),
- ChatCompletionMessage(
- role="user",
- content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
- ),
- ChatCompletionMessage(
- role="assistant",
- content="I can help with that! Would you like a recursive or iterative solution?",
- ),
- ChatCompletionMessage(
- role="user", content="Let's go with a recursive solution."
- ),
- ]
+mistral_7b_tokenizer_config = """{
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [],
+ "bos_token": "",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "",
+ "legacy": true,
+ "model_max_length": 1000000000000000019884624838656,
+ "pad_token": null,
+ "sp_model_kwargs": {},
+ "spaces_between_special_tokens": false,
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "",
+ "use_default_system_prompt": false,
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+}"""
-def test_llama2_formatter(sequence_of_messages):
- expected_prompt = (
- "<> Welcome to CodeHelp Bot! <>\n"
- "[INST] Hi there! I need some help with Python. [/INST]\n"
- "Of course! What do you need help with in Python?\n"
- "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
- "I can help with that! Would you like a recursive or iterative solution?\n"
- "[INST] Let's go with a recursive solution. [/INST]\n"
+def test_hf_tokenizer_config_str_to_chat_formatter():
+ tokenizer_config = json.loads(mistral_7b_tokenizer_config)
+ chat_formatter = hf_tokenizer_config_to_chat_formatter(
+ tokenizer_config
+ )
+    chat_formatter_response = chat_formatter(
+ messages=[
+ ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+ ]
)
- llama2_formatter_instance = Llama2Formatter()
- formatter_response = llama2_formatter_instance(sequence_of_messages)
- assert (
- expected_prompt == formatter_response.prompt
- ), "The formatted prompt does not match the expected output."
-
-
-# Optionally, include a test for the 'stop' if it's part of the functionality.
+    assert chat_formatter_response.prompt == ("<s>[INST] Hello, world! [/INST]" "</s>")