diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py
index f7b6ba6..4a10647 100644
--- a/llama_cpp/_utils.py
+++ b/llama_cpp/_utils.py
@@ -1,7 +1,8 @@
 import os
 import sys
 
-import sys, traceback
+import sys
+from typing import Any, Dict
 
 # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
 outnull_file = open(os.devnull, "w")
@@ -55,3 +56,25 @@ class suppress_stdout_stderr(object):
 
         self.os.close(self.old_stdout_fileno)
         self.os.close(self.old_stderr_fileno)
+
+
+class MetaSingleton(type):
+    """
+    Metaclass for implementing the Singleton pattern.
+    """
+
+    _instances: Dict[type, Any] = {}
+
+    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+        if cls not in cls._instances:
+            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+
+class Singleton(object, metaclass=MetaSingleton):
+    """
+    Base class for implementing the Singleton pattern.
+    """
+
+    def __init__(self):
+        super(Singleton, self).__init__()
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 0ef7bd4..3d18d90 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -6,18 +6,28 @@
 import ctypes
 import dataclasses
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
 
+import jinja2
+
 import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar
 
-from ._utils import suppress_stdout_stderr
+from ._utils import suppress_stdout_stderr, Singleton
 
 
 class LlamaChatCompletionHandler(Protocol):
+    """Base Protocol for a llama chat completion handler.
+
+    Very generic protocol that can be used to implement any chat format.
+    The only hard requirement is that it must return a ChatCompletion when
+    stream=False and an iterator of ChatCompletionChunks when stream=True."""
+
     def __call__(
         self,
         *,
+        # llama.cpp instance
         llama: llama.Llama,
+        # openai api parameters
         messages: List[llama_types.ChatCompletionRequestMessage],
         functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
         function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
@@ -26,8 +36,6 @@ class LlamaChatCompletionHandler(Protocol):
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
-        min_p: float = 0.05,
-        typical_p: float = 1.0,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
@@ -38,14 +46,17 @@ class LlamaChatCompletionHandler(Protocol):
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         repeat_penalty: float = 1.1,
+        model: Optional[str] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        # llama.cpp parameters
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
-        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -54,21 +65,83 @@
         ...
 
 
-CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
+class LlamaChatCompletionHandlerNotFoundException(Exception):
+    pass
+
+
+class LlamaChatCompletionHandlerRegistry(Singleton):
+    _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}
+
+    def register_chat_completion_handler(
+        self,
+        name: str,
+        chat_handler: LlamaChatCompletionHandler,
+        overwrite: bool = False,
+    ):
+        if not overwrite and name in self._chat_handlers:
+            raise ValueError(
+                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
+            )
+        self._chat_handlers[name] = chat_handler
+
+    def unregister_chat_handler(self, name: str):
+        if name in self._chat_handlers:
+            del self._chat_handlers[name]
+        else:
+            raise ValueError(f"No formatter registered under the name '{name}'.")
+
+    def get_chat_completion_handler_by_name(
+        self, name: str
+    ) -> LlamaChatCompletionHandler:
+        try:
+            chat_handler = self._chat_handlers[name]
+            return chat_handler
+        except KeyError:
+            raise LlamaChatCompletionHandlerNotFoundException(
+                f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
+            )
 
 
 def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
-    return CHAT_HANDLERS[name]
+    return LlamaChatCompletionHandlerRegistry().get_chat_completion_handler_by_name(
+        name
+    )
 
 
 def register_chat_completion_handler(name: str):
     def decorator(f: LlamaChatCompletionHandler):
-        CHAT_HANDLERS[name] = f
+        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(name, f)
         return f
 
     return decorator
 
 
+### Chat Formatter ###
+
+
+@dataclasses.dataclass
+class ChatFormatterResponse:
+    prompt: str
+    stop: Optional[Union[str, List[str]]] = None
+
+
+class ChatFormatter(Protocol):
+    """Base Protocol for a chat formatter. A chat formatter is a function that
+    takes a list of messages and returns a formatted prompt. It can also return
+    a stop token or list of stop tokens to use for the completion."""
+
+    def __call__(
+        self,
+        *,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        ...
+
+
+### Utility functions for formatting chat prompts ###
+
+
 def _get_system_message(
     messages: List[llama_types.ChatCompletionRequestMessage],
 ) -> str:
@@ -80,14 +153,18 @@ def _get_system_message(
 
 
 def _map_roles(
-    messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str]
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    role_map: Dict[str, str],
 ) -> List[Tuple[str, Optional[str]]]:
     """Map the message roles."""
     output: List[Tuple[str, Optional[str]]] = []
     for message in messages:
         role = message["role"]
         if role in role_map:
-            output.append((role_map[role], message["content"]))
+            content: str | None = (
+                message["content"] if isinstance(message["content"], str) else None
+            )
+            output.append((role_map[role], content))
     return output
 
 
@@ -99,7 +176,8 @@ def _format_llama2(
     ret = system_message + sep
     for i, (role, message) in enumerate(messages):
         if system_message and i == 0:
-            ret += message + seps[i % 2]
+            m = message or ""
+            ret += m + seps[i % 2]
         elif message:
             ret += role + message + " " + seps[i % 2]
         else:
@@ -172,6 +250,7 @@ def _format_chatml(
         ret += role + "\n"
     return ret
 
+
 def _format_chatglm3(
     system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
 ) -> str:
@@ -187,30 +266,10 @@ def _format_chatglm3(
     return ret
 
 
-@dataclasses.dataclass
-class ChatFormatterResponse:
-    prompt: str
-    stop: Optional[Union[str, List[str]]] = None
-
-
-class ChatFormatter(Protocol):
-    def __call__(
-        self,
-        *,
-        messages: List[llama_types.ChatCompletionRequestMessage],
-        **kwargs: Any,
-    ) -> ChatFormatterResponse:
-        ...
-
-
-class BasicChatHandler:
-    def __init__(self, chat_format: str):
-        self.chat_format = chat_format
-
-
 def _convert_text_completion_to_chat(
     completion: llama_types.Completion,
 ) -> llama_types.ChatCompletion:
+    assert "usage" in completion
     return {
         "id": "chat" + completion["id"],
         "object": "chat.completion",
@@ -286,103 +345,95 @@ def _convert_completion_to_chat(
     return _convert_text_completion_to_chat(completion)
 
 
-_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
+def chat_formatter_to_chat_completion_handler(
+    chat_formatter: ChatFormatter,
+) -> LlamaChatCompletionHandler:
+    def chat_completion_handler(
+        *,
+        llama: llama.Llama,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+        temperature: float = 0.2,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
+        stream: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+        logits_processor: Optional[llama.LogitsProcessorList] = None,
+        grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        **kwargs,  # type: ignore
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        result = chat_formatter(
+            messages=messages,
+            functions=functions,
+            function_call=function_call,
+        )
+        prompt = result.prompt
+        if result.stop is not None:
+            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
+            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
+            stop = stop + rstop
+
+        if response_format is not None and response_format["type"] == "json_object":
+            grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
+            stream=stream,
+            stop=stop,
+            seed=seed,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            logit_bias=logit_bias,
+        )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
+
+    return chat_completion_handler
 
 
 def register_chat_format(name: str):
     def decorator(f: ChatFormatter):
-        def basic_create_chat_completion(
-            *,
-            llama: llama.Llama,
-            messages: List[llama_types.ChatCompletionRequestMessage],
-            functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
-            function_call: Optional[
-                llama_types.ChatCompletionRequestFunctionCall
-            ] = None,
-            tools: Optional[List[llama_types.ChatCompletionTool]] = None,
-            tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
-            temperature: float = 0.2,
-            top_p: float = 0.95,
-            top_k: int = 40,
-            min_p: float = 0.05,
-            typical_p: float = 1.0,
-            stream: bool = False,
-            stop: Optional[Union[str, List[str]]] = [],
-            seed: Optional[int] = None,
-            response_format: Optional[
-                llama_types.ChatCompletionRequestResponseFormat
-            ] = None,
-            max_tokens: Optional[int] = None,
-            presence_penalty: float = 0.0,
-            frequency_penalty: float = 0.0,
-            repeat_penalty: float = 1.1,
-            tfs_z: float = 1.0,
-            mirostat_mode: int = 0,
-            mirostat_tau: float = 5.0,
-            mirostat_eta: float = 0.1,
-            model: Optional[str] = None,
-            logits_processor: Optional[llama.LogitsProcessorList] = None,
-            grammar: Optional[llama.LlamaGrammar] = None,
-            logit_bias: Optional[Dict[str, float]] = None,
-            **kwargs,  # type: ignore
-        ) -> Union[
-            llama_types.CreateChatCompletionResponse,
-            Iterator[llama_types.CreateChatCompletionStreamResponse],
-        ]:
-            result = f(
-                messages=messages,
-                functions=functions,
-                function_call=function_call,
-            )
-            prompt = result.prompt
-            if result.stop is not None:
-                stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
-                rstop = result.stop if isinstance(result.stop, list) else [result.stop]
-                stop = stop + rstop
-
-            if response_format is not None and response_format["type"] == "json_object":
-                grammar = llama_grammar.LlamaGrammar.from_string(
-                    llama_grammar.JSON_GBNF
-                )
-
-            completion_or_chunks = llama.create_completion(
-                prompt=prompt,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                min_p=min_p,
-                typical_p=typical_p,
-                stream=stream,
-                stop=stop,
-                seed=seed,
-                max_tokens=max_tokens,
-                presence_penalty=presence_penalty,
-                frequency_penalty=frequency_penalty,
-                repeat_penalty=repeat_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                model=model,
-                logits_processor=logits_processor,
-                grammar=grammar,
-                logit_bias=logit_bias,
-            )
-            return _convert_completion_to_chat(completion_or_chunks, stream=stream)
-
-        register_chat_completion_handler(name)(basic_create_chat_completion)
-        return f
-
-    return decorator
-
-
-def get_chat_format(name: str):
-    try:
-        return _CHAT_FORMATS[name]
-    except KeyError:
-        raise ValueError(
-            f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
+        chat_completion_handler = chat_formatter_to_chat_completion_handler(f)
+        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
+            name, chat_completion_handler
         )
+        return f
+
+    return decorator
 
 
 def hf_autotokenizer_to_chat_formatter(
@@ -391,22 +442,78 @@
     # https://huggingface.co/docs/transformers/main/chat_templating
     # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
     # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
-    from transformers import AutoTokenizer
+    from transformers import AutoTokenizer  # type: ignore
 
-    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # type: ignore
 
     def format_autotokenizer(
         messages: List[llama_types.ChatCompletionRequestMessage],
         **kwargs: Any,
    ) -> ChatFormatterResponse:
-        tokenizer.use_default_system_prompt = False
-        _prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+        tokenizer.use_default_system_prompt = False  # type: ignore
+        prompt: str = tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
+        assert isinstance(prompt, str)
         # Return formatted prompt and eos token by default
-        return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token)
+        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
 
     return format_autotokenizer
 
 
+def hf_autotokenizer_to_chat_completion_handler(
+    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
+) -> LlamaChatCompletionHandler:
+    chat_formatter = hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
+    return chat_formatter_to_chat_completion_handler(chat_formatter)
+
+
+def hf_tokenizer_config_to_chat_formatter(tokenizer_config: Dict[str, Any]) -> ChatFormatter:
+    assert isinstance(tokenizer_config, dict)
+
+    assert "chat_template" in tokenizer_config
+    assert isinstance(tokenizer_config["chat_template"], str)
+    chat_template = tokenizer_config["chat_template"]
+
+    assert "bos_token" in tokenizer_config
+    assert isinstance(tokenizer_config["bos_token"], str)
+    bos_token = tokenizer_config["bos_token"]
+
+    assert "eos_token" in tokenizer_config
+    assert isinstance(tokenizer_config["eos_token"], str)
+    eos_token = tokenizer_config["eos_token"]
+
+    env = jinja2.Environment(
+        loader=jinja2.BaseLoader(),
+        trim_blocks=True,
+        lstrip_blocks=True,
+    ).from_string(chat_template)
+
+    def format_autotokenizer(
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        **kwargs: Any,
+    ) -> ChatFormatterResponse:
+        # TODO: verify this is correct
+        # Add a blank assistant message to the end of the messages to prompt the model to generate a response
+        prompt = env.render(
+            messages=[
+                *messages,
+                llama_types.ChatCompletionRequestAssistantMessage(
+                    role="assistant", content=""
+                ),
+            ],
+            bos_token=bos_token,
+            eos_token=eos_token,
+        )
+        return ChatFormatterResponse(prompt=prompt, stop=eos_token)
+    return format_autotokenizer
+
+
+def hf_tokenizer_config_to_chat_completion_handler(
+    tokenizer_config: Dict[str, Any],
+) -> LlamaChatCompletionHandler:
+    chat_formatter = hf_tokenizer_config_to_chat_formatter(tokenizer_config)
+    return chat_formatter_to_chat_completion_handler(chat_formatter)
+
+
 # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
 # system prompt is "embedded" in the first message
 @register_chat_format("llama-2")
@@ -437,21 +544,23 @@ def format_alpaca(
     _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
     return ChatFormatterResponse(prompt=_prompt)
 
+
 @register_chat_format("qwen")
 def format_qwen(
     messages: List[llama_types.ChatCompletionRequestMessage],
     **kwargs: Any,
 ) -> ChatFormatterResponse:
     _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
-    system_message="You are a helpful assistant."
-    system_template="<|im_start|>system\n{system_message}"
-    system_message=system_template.format(system_message=system_message)
+    system_message = "You are a helpful assistant."
+    system_template = "<|im_start|>system\n{system_message}"
+    system_message = system_template.format(system_message=system_message)
     _messages = _map_roles(messages, _roles)
     _messages.append((_roles["assistant"], None))
     _sep = "<|im_end|>"
     _prompt = _format_chatml(system_message, _messages, _sep)
     _sep2 = "<|endoftext|>"
-    return ChatFormatterResponse(prompt=_prompt,stop=_sep2)
+    return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
+
 
 @register_chat_format("vicuna")
 def format(
@@ -650,6 +759,7 @@ def format_mistrallite(
     _prompt = _format_no_colon_single(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt)
 
+
 @register_chat_format("zephyr")
 def format_zephyr(
     messages: List[llama_types.ChatCompletionRequestMessage],
@@ -699,6 +809,7 @@ def format_chatml(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
+
 @register_chat_format("chatglm3")
 def format_chatglm3(
     messages: List[llama_types.ChatCompletionRequestMessage],
@@ -739,7 +850,7 @@ def format_openchat(
 @register_chat_format("saiga")
 def format_saiga(
     messages: list[llama_types.ChatCompletionRequestMessage],
-    **kwargs,
+    **kwargs: Any,
 ) -> ChatFormatterResponse:
     _message_template = "{role}\n{content}"
     _roles = dict(user="user", bot="bot", system="system")
diff --git a/llama_cpp/llama_jinja_format.py b/llama_cpp/llama_jinja_format.py
deleted file mode 100644
index 68faaf6..0000000
--- a/llama_cpp/llama_jinja_format.py
+++ /dev/null
@@ -1,138 +0,0 @@
-"""
-llama_cpp/llama_jinja_format.py
-"""
-import dataclasses
-from typing import Any, Callable, Dict, List, Optional, Protocol, Union
-
-import jinja2
-from jinja2 import Template
-
-# NOTE: We sacrifice readability for usability.
-# It will fail to work as expected if we attempt to format it in a readable way.
-llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""
-
-
-class MetaSingleton(type):
-    """
-    Metaclass for implementing the Singleton pattern.
-    """
-
-    _instances = {}
-
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
-        return cls._instances[cls]
-
-
-class Singleton(object, metaclass=MetaSingleton):
-    """
-    Base class for implementing the Singleton pattern.
-    """
-
-    def __init__(self):
-        super(Singleton, self).__init__()
-
-
-@dataclasses.dataclass
-class ChatFormatterResponse:
-    prompt: str
-    stop: Optional[Union[str, List[str]]] = None
-
-
-# Base Chat Formatter Protocol
-class ChatFormatterInterface(Protocol):
-    def __init__(self, template: Optional[object] = None):
-        ...
-
-    def __call__(
-        self,
-        messages: List[Dict[str, str]],
-        **kwargs,
-    ) -> ChatFormatterResponse:
-        ...
-
-    @property
-    def template(self) -> str:
-        ...
-
-
-class AutoChatFormatter(ChatFormatterInterface):
-    def __init__(
-        self,
-        template: Optional[str] = None,
-        template_class: Optional[Template] = None,
-    ):
-        if template is not None:
-            self._template = template
-        else:
-            self._template = llama2_template  # default template
-
-        self._environment = jinja2.Environment(
-            loader=jinja2.BaseLoader(),
-            trim_blocks=True,
-            lstrip_blocks=True,
-        ).from_string(
-            self._template,
-            template_class=template_class,
-        )
-
-    def __call__(
-        self,
-        messages: List[Dict[str, str]],
-        **kwargs: Any,
-    ) -> ChatFormatterResponse:
-        formatted_sequence = self._environment.render(messages=messages, **kwargs)
-        return ChatFormatterResponse(prompt=formatted_sequence)
-
-    @property
-    def template(self) -> str:
-        return self._template
-
-
-class FormatterNotFoundException(Exception):
-    pass
-
-
-class ChatFormatterFactory(Singleton):
-    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}
-
-    def register_formatter(
-        self,
-        name: str,
-        formatter_callable: Callable[[], ChatFormatterInterface],
-        overwrite=False,
-    ):
-        if not overwrite and name in self._chat_formatters:
-            raise ValueError(
-                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
-            )
-        self._chat_formatters[name] = formatter_callable
-
-    def unregister_formatter(self, name: str):
-        if name in self._chat_formatters:
-            del self._chat_formatters[name]
-        else:
-            raise ValueError(f"No formatter registered under the name '{name}'.")
-
-    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
-        try:
-            formatter_callable = self._chat_formatters[name]
-            return formatter_callable()
-        except KeyError:
-            raise FormatterNotFoundException(
-                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
-            )
-
-
-# Define a chat format class
-class Llama2Formatter(AutoChatFormatter):
-    def __init__(self):
-        super().__init__(llama2_template)
-
-
-# With the Singleton pattern applied, regardless of where or how many times
-# ChatFormatterFactory() is called, it will always return the same instance
-# of the factory, ensuring that the factory's state is consistent throughout
-# the application.
-ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py
index f9be323..c2d6b6d 100644
--- a/llama_cpp/server/model.py
+++ b/llama_cpp/server/model.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import json
+
 from typing import Dict, Optional, Union, List
 
 import llama_cpp
@@ -71,7 +73,25 @@ class LlamaProxy:
             chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
                 clip_model_path=settings.clip_model_path, verbose=settings.verbose
             )
-
+        elif settings.chat_format == "hf-autotokenizer":
+            assert (
+                settings.hf_pretrained_model_name_or_path is not None
+            ), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
+            chat_handler = (
+                llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_completion_handler(
+                    settings.hf_pretrained_model_name_or_path
+                )
+            )
+        elif settings.chat_format == "hf-tokenizer-config":
+            assert (
+                settings.hf_tokenizer_config_path is not None
+            ), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
+            chat_handler = (
+                llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
+                    json.load(open(settings.hf_tokenizer_config_path))
+                )
+            )
+
         kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
         if settings.kv_overrides is not None:
             assert isinstance(settings.kv_overrides, list)
@@ -141,4 +161,3 @@ class LlamaProxy:
             cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
             _model.set_cache(cache)
         return _model
-
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index dc5be20..9f0dc8a 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -134,6 +134,15 @@ class ModelSettings(BaseSettings):
         default=2 << 30,
         description="The size of the cache in bytes. Only used if cache is True.",
     )
+    # Tokenizer Options
+    hf_tokenizer_config_path: Optional[str] = Field(
+        default=None,
+        description="The path to a HuggingFace tokenizer_config.json file.",
+    )
+    hf_pretrained_model_name_or_path: Optional[str] = Field(
+        default=None,
+        description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
+    )
     # Misc
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
diff --git a/tests/test_llama_chat_format.py b/tests/test_llama_chat_format.py
index 4eebcb6..1ef18d9 100644
--- a/tests/test_llama_chat_format.py
+++ b/tests/test_llama_chat_format.py
@@ -1,50 +1,65 @@
-from typing import List
+import json
 
-import pytest
-
-from llama_cpp import ChatCompletionMessage
-from llama_cpp.llama_jinja_format import Llama2Formatter
+from llama_cpp import (
+    ChatCompletionRequestUserMessage,
+)
+from llama_cpp.llama_chat_format import hf_tokenizer_config_to_chat_formatter
 
 
-@pytest.fixture
-def sequence_of_messages() -> List[ChatCompletionMessage]:
-    return [
-        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
-        ChatCompletionMessage(
-            role="user", content="Hi there! I need some help with Python."
-        ),
-        ChatCompletionMessage(
-            role="assistant", content="Of course! What do you need help with in Python?"
-        ),
-        ChatCompletionMessage(
-            role="user",
-            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
-        ),
-        ChatCompletionMessage(
-            role="assistant",
-            content="I can help with that! Would you like a recursive or iterative solution?",
-        ),
-        ChatCompletionMessage(
-            role="user", content="Let's go with a recursive solution."
-        ),
-    ]
+mistral_7b_tokenizer_config = """{
+  "add_bos_token": true,
+  "add_eos_token": false,
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<unk>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2": {
+      "content": "</s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": [],
+  "bos_token": "<s>",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "</s>",
+  "legacy": true,
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": null,
+  "sp_model_kwargs": {},
+  "spaces_between_special_tokens": false,
+  "tokenizer_class": "LlamaTokenizer",
+  "unk_token": "<unk>",
+  "use_default_system_prompt": false,
+  "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
+}"""
 
 
-def test_llama2_formatter(sequence_of_messages):
-    expected_prompt = (
-        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
-        "[INST] Hi there! I need some help with Python. [/INST]\n"
-        "Of course! What do you need help with in Python?\n"
-        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
-        "I can help with that! Would you like a recursive or iterative solution?\n"
-        "[INST] Let's go with a recursive solution. [/INST]\n"
+def test_hf_tokenizer_config_str_to_chat_formatter():
+    tokenizer_config = json.loads(mistral_7b_tokenizer_config)
+    chat_formatter = hf_tokenizer_config_to_chat_formatter(
+        tokenizer_config
+    )
+    chat_formatter_response = chat_formatter(
+        messages=[
+            ChatCompletionRequestUserMessage(role="user", content="Hello, world!"),
+        ]
     )
-    llama2_formatter_instance = Llama2Formatter()
-    formatter_response = llama2_formatter_instance(sequence_of_messages)
-    assert (
-        expected_prompt == formatter_response.prompt
-    ), "The formatted prompt does not match the expected output."
-
-
-# Optionally, include a test for the 'stop' if it's part of the functionality.
+    assert chat_formatter_response.prompt == ("<s>[INST] Hello, world! [/INST]" "</s>")