diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 037f96a..6f402e0 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -172,6 +172,20 @@ def _format_chatml(
         ret += role + "\n"
     return ret
 
+def _format_chatglm3(
+    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
+) -> str:
+    """Format the prompt with the chatglm3 style."""
+    ret = ""
+    if system_message:
+        ret += system_message
+    for role, message in messages:
+        if message:
+            ret += role + "\n" + " " + message
+        else:
+            ret += role
+    return ret
+
 
 @dataclasses.dataclass
 class ChatFormatterResponse:
@@ -685,6 +699,22 @@ def format_chatml(
     _prompt = _format_chatml(system_message, _messages, _sep)
     return ChatFormatterResponse(prompt=_prompt, stop=_sep)
 
+@register_chat_format("chatglm3")
+def format_chatglm3(
+    messages: List[llama_types.ChatCompletionRequestMessage],
+    **kwargs: Any,
+) -> ChatFormatterResponse:
+    system_template = """<|system|>
+{system_message}"""
+    system_message = _get_system_message(messages)
+    system_message = system_template.format(system_message=system_message)
+    _roles = dict(user="<|user|>", assistant="<|assistant|>")
+    _sep = "</s>"
+    _messages = _map_roles(messages, _roles)
+    _messages.append((_roles["assistant"], None))
+    _prompt = _format_chatglm3(system_message, _messages, _sep)
+    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
+
 
 @register_chat_format("openchat")
 def format_openchat(
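
For reference, here is a minimal standalone sketch of the prompt string this format produces. The build_chatglm3_prompt helper below is hypothetical, written only to mirror the _format_chatglm3 logic added above:

from typing import List, Optional, Tuple

def build_chatglm3_prompt(
    system_message: str, messages: List[Tuple[str, Optional[str]]]
) -> str:
    # Mirrors _format_chatglm3: the system block comes first, then each
    # role tag on its own line followed by a space-indented message; a
    # trailing role tag with no message cues the model to generate.
    ret = ""
    if system_message:
        ret += system_message
    for role, message in messages:
        if message:
            ret += role + "\n" + " " + message
        else:
            ret += role
    return ret

print(build_chatglm3_prompt(
    "<|system|>\nYou are a helpful assistant.",
    [("<|user|>", "Hi"), ("<|assistant|>", None)],
))
# <|system|>
# You are a helpful assistant.<|user|>
#  Hi<|assistant|>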
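
Once merged, the name registered via register_chat_format should make the new format selectable from the high-level API; a usage sketch (the model path is a placeholder):

from llama_cpp import Llama

llm = Llama(model_path="./chatglm3-6b.Q4_K_M.gguf", chat_format="chatglm3")
response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
    ]
)
print(response["choices"][0]["message"]["content"])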