3580e2c5df
* Update llama_chat_format.py: properly format llama2 with the system prompt embedded in the first message * Update llama_chat_format.py
714 lines · 26 KiB · Python
from __future__ import annotations

import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol

from . import llama_types
from . import llama


class LlamaChatCompletionHandler(Protocol):
    def __call__(
        self,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[
            Union[str, llama_types.ChatCompletionFunctionCall]
        ] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        max_tokens: int = 256,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
    ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
        ...


CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}


def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
    return CHAT_HANDLERS[name]


def register_chat_completion_handler(name: str):
    def decorator(f: LlamaChatCompletionHandler):
        CHAT_HANDLERS[name] = f
        return f

    return decorator


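# Illustrative use of the registry above (a sketch, not part of this module):
# a handler registered under a name can be looked up and invoked directly.
# The model path below is a placeholder.
#
#   from llama_cpp import Llama
#   from llama_cpp.llama_chat_format import get_chat_completion_handler
#
#   llm = Llama(model_path="./model.gguf")  # hypothetical path
#   handler = get_chat_completion_handler("llama-2")
#   response = handler(llama=llm, messages=[{"role": "user", "content": "Hello"}])

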
def _get_system_message(
    messages: List[llama_types.ChatCompletionRequestMessage],
) -> str:
    """Get the first system message."""
    for message in messages:
        if message["role"] == "system":
            return message["content"] or ""
    return ""


def _map_roles(
    messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str]
) -> List[Tuple[str, Optional[str]]]:
    """Map the message roles."""
    output: List[Tuple[str, Optional[str]]] = []
    for message in messages:
        role = message["role"]
        if role in role_map:
            output.append((role_map[role], message["content"]))
    return output


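# Example of the role mapping above (illustrative values chosen for the sketch):
#
#   _map_roles(
#       [{"role": "system", "content": "Be brief."},
#        {"role": "user", "content": "Hi"},
#        {"role": "assistant", "content": "Hello"}],
#       {"user": "USER", "assistant": "ASSISTANT"},
#   )
#   # -> [("USER", "Hi"), ("ASSISTANT", "Hello")]
#   # roles missing from role_map (here "system") are dropped

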
def _format_llama2(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
    """Format the prompt with the llama2 style."""
    seps = [sep, sep2]
    ret = system_message + sep
    for i, (role, message) in enumerate(messages):
        if system_message and i == 0:
            ret += message + seps[i % 2]
        elif message:
            ret += role + message + " " + seps[i % 2]
        else:
            ret += role + " "
    return ret


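# Worked example for the llama2-style loop above (a sketch traced by hand; the
# exact role tokens come from the "llama-2" formatter further down):
#
#   _format_llama2(
#       "<s>[INST] <<SYS>>\nBe brief.\n<</SYS>>",
#       [("<s>[INST]", "Hi"), ("[/INST]", "Hello!")],
#       " ",
#       "</s>",
#   )
#   # -> "<s>[INST] <<SYS>>\nBe brief.\n<</SYS>> Hi [/INST]Hello! </s>"
#   # the first user message is appended right after the system block; later
#   # turns alternate between sep and sep2

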
def _format_add_colon_single(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
    """Format the prompt with the add-colon-single style."""
    ret = system_message + sep
    for role, message in messages:
        if message:
            ret += role + ": " + message + sep
        else:
            ret += role + ":"
    return ret


def _format_add_colon_two(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str
) -> str:
    """Format the prompt with the add-colon-two style."""
    seps = [sep, sep2]
    ret = system_message + seps[0]
    for i, (role, message) in enumerate(messages):
        if message:
            ret += role + ": " + message + seps[i % 2]
        else:
            ret += role + ":"
    return ret


def _format_no_colon_single(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
    """Format the prompt with the no-colon-single style."""
    ret = system_message
    for role, message in messages:
        if message:
            ret += role + message + sep
        else:
            ret += role
    return ret


def _format_add_colon_space_single(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
    """Format the prompt with the add-colon-space-single style."""
    ret = system_message + sep
    for role, message in messages:
        if message:
            ret += role + ": " + message + sep
        else:
            ret += role + ": "  # must end with a space
    return ret


def _format_chatml(
    system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str
) -> str:
    """Format the prompt with the chatml style."""
    ret = "" if system_message == "" else system_message + sep + "\n"
    for role, message in messages:
        if message:
            ret += role + "\n" + message + sep + "\n"
        else:
            ret += role + "\n"
    return ret


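# Worked example for the chatml-style loop above (a sketch; the role strings and
# separator come from the "chatml" formatter registered below):
#
#   _format_chatml(
#       "<|im_start|>system\nBe brief.",
#       [("<|im_start|>user", "Hi"), ("<|im_start|>assistant", None)],
#       "<|im_end|>",
#   )
#   # -> "<|im_start|>system\nBe brief.<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"

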
@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


class ChatFormatter(Protocol):
    def __call__(
        self,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        ...


class BasicChatHandler:
    def __init__(self, chat_format: str):
        self.chat_format = chat_format


def _convert_text_completion_to_chat(
    completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
    return {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": completion["choices"][0]["text"],
                },
                "finish_reason": completion["choices"][0]["finish_reason"],
            }
        ],
        "usage": completion["usage"],
    }


def _convert_text_completion_chunks_to_chat(
    chunks: Iterator[llama_types.CompletionChunk],
) -> Iterator[llama_types.ChatCompletionChunk]:
    for i, chunk in enumerate(chunks):
        if i == 0:
            yield {
                "id": "chat" + chunk["id"],
                "model": chunk["model"],
                "created": chunk["created"],
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                        },
                        "finish_reason": None,
                    }
                ],
            }
        yield {
            "id": "chat" + chunk["id"],
            "model": chunk["model"],
            "created": chunk["created"],
            "object": "chat.completion.chunk",
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": chunk["choices"][0]["text"],
                    }
                    if chunk["choices"][0]["finish_reason"] is None
                    else {},
                    "finish_reason": chunk["choices"][0]["finish_reason"],
                }
            ],
        }


def _convert_completion_to_chat(
    completion_or_chunks: Union[
        llama_types.Completion, Iterator[llama_types.CompletionChunk]
    ],
    stream: bool = False,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    if stream:
        chunks: Iterator[llama_types.CompletionChunk] = completion_or_chunks  # type: ignore
        return _convert_text_completion_chunks_to_chat(chunks)
    else:
        completion: llama_types.Completion = completion_or_chunks  # type: ignore
        return _convert_text_completion_to_chat(completion)


_CHAT_FORMATS: Dict[str, ChatFormatter] = {}


def register_chat_format(name: str):
    def decorator(f: ChatFormatter):
        def basic_create_chat_completion(
            llama: llama.Llama,
            messages: List[llama_types.ChatCompletionRequestMessage],
            functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
            function_call: Optional[
                Union[str, llama_types.ChatCompletionFunctionCall]
            ] = None,
            temperature: float = 0.2,
            top_p: float = 0.95,
            top_k: int = 40,
            stream: bool = False,
            stop: Optional[Union[str, List[str]]] = [],
            max_tokens: int = 256,
            presence_penalty: float = 0.0,
            frequency_penalty: float = 0.0,
            repeat_penalty: float = 1.1,
            tfs_z: float = 1.0,
            mirostat_mode: int = 0,
            mirostat_tau: float = 5.0,
            mirostat_eta: float = 0.1,
            model: Optional[str] = None,
            logits_processor: Optional[llama.LogitsProcessorList] = None,
            grammar: Optional[llama.LlamaGrammar] = None,
        ) -> Union[
            llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
        ]:
            result = f(
                messages=messages,
                functions=functions,
                function_call=function_call,
            )
            prompt = result.prompt
            if result.stop is not None:
                stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
                rstop = result.stop if isinstance(result.stop, list) else [result.stop]
                stop = stop + rstop

            completion_or_chunks = llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stream=stream,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repeat_penalty=repeat_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                model=model,
                logits_processor=logits_processor,
                grammar=grammar,
            )
            return _convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore

        # Register both the raw formatter and the derived handler so that
        # get_chat_format(name) and get_chat_completion_handler(name) stay in sync.
        _CHAT_FORMATS[name] = f
        register_chat_completion_handler(name)(basic_create_chat_completion)
        return f

    return decorator


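# Sketch of how a new prompt style could be registered with the decorator above
# (illustrative only; the "my-format" name and its template are made up):
#
#   @register_chat_format("my-format")
#   def format_my_format(
#       messages: List[llama_types.ChatCompletionRequestMessage],
#       **kwargs: Any,
#   ) -> ChatFormatterResponse:
#       _roles = dict(user="User", assistant="Assistant")
#       _messages = _map_roles(messages, _roles)
#       _messages.append((_roles["assistant"], None))
#       _prompt = _format_add_colon_single(_get_system_message(messages), _messages, "\n")
#       return ChatFormatterResponse(prompt=_prompt, stop="User:")

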
def get_chat_format(name: str):
    try:
        return _CHAT_FORMATS[name]
    except KeyError:
        raise ValueError(
            f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})"
        )


# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
@register_chat_format("llama-2")
def format_llama2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    _system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
    _roles = dict(user="<s>[INST]", assistant="[/INST]")
    _messages = _map_roles(messages, _roles)
    system_message = _get_system_message(messages)
    if system_message:
        system_message = _system_template.format(system_message=system_message)
    _prompt = _format_llama2(system_message, _messages, " ", "</s>") + "[/INST]"
    return ChatFormatterResponse(prompt=_prompt)


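# Illustrative prompt produced by the "llama-2" formatter above (a sketch traced
# through _format_llama2 by hand, not a captured output):
#
#   format_llama2(
#       messages=[
#           {"role": "system", "content": "Be brief."},
#           {"role": "user", "content": "Hi"},
#       ]
#   ).prompt
#   # -> "<s>[INST] <<SYS>>\nBe brief.\n<</SYS>> Hi [/INST]"

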
@register_chat_format("alpaca")
|
|
def format_alpaca(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
_roles = dict(user="### Instruction", assistant="### Response")
|
|
_sep = "\n\n"
|
|
_sep2 = "</s>"
|
|
system_message = _get_system_message(messages)
|
|
_messages = _map_roles(messages, _roles)
|
|
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
|
|
return ChatFormatterResponse(prompt=_prompt)
|
|
|
|
|
|
@register_chat_format("vicuna")
|
|
def format(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
_system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
|
|
_roles = dict(user="USER", assistant="ASSISTANT")
|
|
_sep = " "
|
|
_sep2 = "</s>"
|
|
system_message = _system_message
|
|
_messages = _map_roles(messages, _roles)
|
|
_messages.append((_roles["assistant"], None))
|
|
_prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
|
|
return ChatFormatterResponse(prompt=_prompt)
|
|
|
|
|
|
@register_chat_format("oasst_llama")
|
|
def format_oasst_llama(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
_system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
|
|
_roles = dict(user="<|prompter|>", assistant="<|assistant|>")
|
|
_sep = "</s>"
|
|
system_message = _get_system_message(messages)
|
|
system_message = _system_template.format(system_message=system_message)
|
|
_messages = _map_roles(messages, _roles)
|
|
_messages.append((_roles["assistant"], None))
|
|
_prompt = _format_no_colon_single(system_message, _messages, _sep)
|
|
return ChatFormatterResponse(prompt=_prompt)
|
|
|
|
|
|
@register_chat_format("openbuddy")
|
|
def format_openbuddy(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
_system_message = """Consider a conversation between User (a human) and Assistant (named Buddy).
|
|
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
|
|
Buddy cannot access the Internet.
|
|
Buddy can fluently speak the user's language (e.g. English, Chinese).
|
|
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
|
|
Buddy possesses vast knowledge about the world, history, and culture.
|
|
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
|
|
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
|
|
|
|
User: Hi.
|
|
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?"""
|
|
_roles = dict(user="User", assistant="Assistant")
|
|
_sep = "\n"
|
|
system_message = _system_message
|
|
_messages = _map_roles(messages, _roles)
|
|
_messages.append((_roles["assistant"], None))
|
|
_prompt = _format_add_colon_single(system_message, _messages, _sep)
|
|
return ChatFormatterResponse(prompt=_prompt)
|
|
|
|
|
|
@register_chat_format("redpajama-incite")
|
|
def format_redpajama_incite(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
_system_message = _get_system_message(messages)
|
|
_roles = dict(user="<human>", assistant="<bot>")
|
|
_sep = "\n"
|
|
_stop = "<human>"
|
|
system_message = _system_message
|
|
_messages = _map_roles(messages, _roles)
|
|
_messages.append((_roles["assistant"], None))
|
|
_prompt = _format_add_colon_single(system_message, _messages, _sep)
|
|
return ChatFormatterResponse(prompt=_prompt, stop=_stop)
|
|
|
|
|
|
@register_chat_format("snoozy")
|
|
def format_snoozy(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
system_template = "### Instruction:\n{system_message}"
|
|
default_system_message = "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response."
|
|
_system_message = _get_system_message(messages)
|
|
_system_message = (
|
|
_system_message if _system_message != "" else default_system_message
|
|
)
|
|
system_message = system_template.format(system_message=_system_message)
|
|
_roles = dict(user="### Prompt", assistant="### Response")
|
|
_sep = "\n"
|
|
_stop = "###"
|
|
system_message = _system_message
|
|
_messages = _map_roles(messages, _roles)
|
|
_messages.append((_roles["assistant"], None))
|
|
_prompt = _format_add_colon_single(system_message, _messages, _sep)
|
|
return ChatFormatterResponse(prompt=_prompt, stop=_stop)
|
|
|
|
|
|
@register_chat_format("phind")
|
|
def format_phind(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
_roles = dict(user="### User Message", assistant="### Assistant")
|
|
_sep = "\n\n"
|
|
_system_message = "### System Prompt\nYou are an intelligent programming assistant."
|
|
_messages = _map_roles(messages, _roles)
|
|
_messages.append((_roles["assistant"], None))
|
|
_prompt = _format_add_colon_single(_system_message, _messages, _sep)
|
|
return ChatFormatterResponse(prompt=_prompt)
|
|
|
|
|
|
@register_chat_format("open-orca")
|
|
def format_open_orca(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
system_template = "{system_message}"
|
|
system_message = (
|
|
"You are a helpful assistant. Please answer truthfully and write out your "
|
|
)
|
|
"thinking step by step to be sure you get the right answer. If you make a mistake or encounter "
|
|
"an error in your thinking, say so out loud and attempt to correct it. If you don't know or "
|
|
"aren't sure about something, say so clearly. You will act as a professional logician, mathematician, "
|
|
"and physicist. You will also act as the most appropriate type of expert to answer any particular "
|
|
"question or solve the relevant problem; state which expert type your are, if so. Also think of "
|
|
"any particular named expert that would be ideal to answer the relevant question or solve the "
|
|
"relevant problem; name and act as them, if appropriate."
|
|
roles = ("User", "Assistant")
|
|
sep = "<|end_of_turn|>\n"
|
|
# stop_token_ids=[32000, 32001], # "<|end_of_turn|>"
|
|
stop_str = "User"
|
|
system_message = system_template.format(system_message=system_message)
|
|
_messages = _map_roles(messages, dict(zip(roles, roles)))
|
|
_messages.append((roles[1], None))
|
|
_prompt = _format_add_colon_space_single(system_message, _messages, sep)
|
|
return ChatFormatterResponse(prompt=_prompt, stop=stop_str)
|
|
|
|
|
|
@register_chat_format("chatml")
|
|
def format_chatml(
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
**kwargs: Any,
|
|
) -> ChatFormatterResponse:
|
|
system_template = """<|im_start|>system
|
|
{system_message}"""
|
|
system_message = _get_system_message(messages)
|
|
system_message = system_template.format(system_message=system_message)
|
|
_roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
|
|
_sep = "<|im_end|>"
|
|
_messages = _map_roles(messages, _roles)
|
|
_messages.append((_roles["assistant"], None))
|
|
_prompt = _format_chatml(system_message, _messages, _sep)
|
|
return ChatFormatterResponse(prompt=_prompt)
|
|
|
|
|
|
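# Typical way the registered formats above are consumed (an illustrative sketch;
# it assumes the high-level Llama class accepts a chat_format argument and that a
# local GGUF model exists at the placeholder path):
#
#   from llama_cpp import Llama
#
#   llm = Llama(model_path="./model.gguf", chat_format="chatml")  # hypothetical path
#   llm.create_chat_completion(
#       messages=[
#           {"role": "system", "content": "Be brief."},
#           {"role": "user", "content": "Hi"},
#       ]
#   )

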
@register_chat_completion_handler("functionary")
|
|
def functionary_chat_handler(
|
|
llama: llama.Llama,
|
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
|
function_call: Optional[Union[str, llama_types.ChatCompletionFunctionCall]] = None,
|
|
temperature: float = 0.2,
|
|
top_p: float = 0.95,
|
|
top_k: int = 40,
|
|
stream: bool = False,
|
|
stop: Optional[Union[str, List[str]]] = [],
|
|
max_tokens: int = 256,
|
|
presence_penalty: float = 0.0,
|
|
frequency_penalty: float = 0.0,
|
|
repeat_penalty: float = 1.1,
|
|
tfs_z: float = 1.0,
|
|
mirostat_mode: int = 0,
|
|
mirostat_tau: float = 5.0,
|
|
mirostat_eta: float = 0.1,
|
|
model: Optional[str] = None,
|
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
|
grammar: Optional[llama.LlamaGrammar] = None,
|
|
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
|
|
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
|
|
|
|
    def generate_schema_from_functions(
        functions: List[llama_types.ChatCompletionFunctions],
        namespace: str = "functions",
    ):
        """
        Convert functions schema to a schema that language models can understand.
        """

        schema = (
            "// Supported function definitions that should be called when necessary.\n"
        )
        schema += f"namespace {namespace} {{\n\n"

        for function in functions:
            # Convert a Function object to dict, if necessary
            function_name = function["name"]
            description = function.get("description", "")
            schema += f"// {description}\n"
            schema += f"type {function_name}"

            parameters = function.get("parameters", None)
            schema += " = (_: {\n"
            required_params = parameters.get("required", [])
            for param_name, param in parameters.get("properties", {}).items():
                # Param Description
                description = param.get("description")
                if description is not None:
                    schema += f"// {description}\n"

                # Param Name
                schema += f"{param_name}"
                if param_name not in required_params:
                    schema += "?"

                # Param Type
                param_type = param.get("type", "any")
                if param_type == "integer":
                    param_type = "number"
                if "enum" in param:
                    param_type = " | ".join([f'"{v}"' for v in param["enum"]])
                schema += f": {param_type},\n"

            schema += "}) => any;\n\n"

        schema += f"}} // namespace {namespace}"

        return schema

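    # Example of the schema text produced above for a single function (an
    # illustrative sketch; the get_current_weather definition is made up):
    #
    #   generate_schema_from_functions([{
    #       "name": "get_current_weather",
    #       "description": "Get the current weather",
    #       "parameters": {
    #           "type": "object",
    #           "properties": {"location": {"type": "string"}},
    #           "required": ["location"],
    #       },
    #   }])
    #   # // Supported function definitions that should be called when necessary.
    #   # namespace functions {
    #   #
    #   # // Get the current weather
    #   # type get_current_weather = (_: {
    #   # location: string,
    #   # }) => any;
    #   #
    #   # } // namespace functions
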
    def prepare_messages_for_inference(
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
    ):
        all_messages: List[llama_types.ChatCompletionRequestMessage] = []
        if functions is not None:
            all_messages.append(
                llama_types.ChatCompletionRequestMessage(
                    role="system", content=generate_schema_from_functions(functions)
                )
            )

        all_messages.append(
            llama_types.ChatCompletionRequestMessage(
                role="system", content=SYSTEM_MESSAGE
            )
        )

        for message in messages:
            # Function call responses
            if message["role"] == "function" and "name" in message:
                message["name"] = f"functions.{message['name']}"
            # Function call requests by assistant
            if "function_call" in message:
                message["function_call"][
                    "name"
                ] = f"functions.{message['function_call']['name']}"
            all_messages.append(message)

        all_messages.append(
            llama_types.ChatCompletionRequestMessage(role="assistant", content=None)
        )

        def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
            if msg["role"] == "system":
                return f"system:\n{msg['content']}\n"
            elif msg["role"] == "function" and "name" in msg:
                return f"function name={msg['name']}:\n{msg['content']}\n"
            elif msg["role"] == "function" and "function_call" in msg:
                return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
            elif msg["role"] == "user":
                if msg["content"] is None:
                    return "user:\n</s>"
                else:
                    return f"user:\n</s>{msg['content']}\n"
            elif msg["role"] == "assistant":
                if msg["content"] is not None and "function_call" in msg:
                    return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
                elif "function_call" in msg:
                    return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
                elif msg["content"] is None:
                    return "assistant"
                else:
                    return f"assistant:\n{msg['content']}\n"
            else:
                raise ValueError(f"Unsupported role: {msg['role']}")

        return "".join([message_to_str(msg) for msg in all_messages])

    prompt = prepare_messages_for_inference(messages, functions)

    if function_call is None and (functions is None or len(functions) == 0):
        completion_or_completion_chunks = llama.create_completion(
            prompt=prompt + ":\n",
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream=stream,
            stop=["user:", "</s>"],
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        )
        return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream)  # type: ignore

    if function_call is None or (
        isinstance(function_call, str) and function_call == "auto"
    ):
        stop = "\n"
        completion: llama_types.Completion = llama.create_completion(
            prompt=prompt, stop=stop, stream=False
        )  # type: ignore
        completion_text = completion["choices"][0]["text"]
        # strip " to=functions." and ending ":"
        function_call = completion_text[14:-1]
        new_prompt = prompt + completion_text + stop
    elif isinstance(function_call, str) and function_call != "none":
        new_prompt = prompt + f"assistant:\n"
    elif isinstance(function_call, dict):
        new_prompt = prompt + f"assistant to={function_call['name']}:\n"
        function_call = function_call["name"]
    else:
        new_prompt = prompt + f"assistant:\n"

    completion: llama_types.Completion = llama.create_completion(
        prompt=new_prompt, stop=["user:", "</s>"], stream=False
    )  # type: ignore

    return llama_types.CreateChatCompletionResponse(
        id="chat" + completion["id"],
        object="chat.completion",
        created=completion["created"],
        model=completion["model"],
        choices=[
            {
                "index": 0,
                "message": {
                    "role": "function",
                    "content": None,
                    "function_call": {
                        "name": function_call,
                        "arguments": completion["choices"][0]["text"],
                    },
                },
                "finish_reason": "function_call",
            }
        ],
        usage=completion["usage"],
    )
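

# Illustrative call into the functionary handler above (a sketch; the model path
# and the get_current_weather function definition are placeholders, and it assumes
# a functionary-style GGUF model is loaded):
#
#   from llama_cpp import Llama
#   from llama_cpp.llama_chat_format import get_chat_completion_handler
#
#   llm = Llama(model_path="./functionary.gguf")  # hypothetical path
#   handler = get_chat_completion_handler("functionary")
#   handler(
#       llama=llm,
#       messages=[{"role": "user", "content": "What is the weather in Paris?"}],
#       functions=[{
#           "name": "get_current_weather",
#           "description": "Get the current weather",
#           "parameters": {
#               "type": "object",
#               "properties": {"location": {"type": "string"}},
#               "required": ["location"],
#           },
#       }],
#       function_call="auto",
#   )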