2023-11-03 06:12:14 +00:00
from __future__ import annotations
2023-11-06 14:07:27 +00:00
import os
2023-11-10 07:51:58 +00:00
import json
2023-11-08 03:48:51 +00:00
import ctypes
2023-09-29 23:52:04 +00:00
import dataclasses
2023-11-03 06:12:14 +00:00
from typing import Any , Dict , Iterator , List , Optional , Tuple , Union , Protocol
2023-11-06 14:07:27 +00:00
2024-01-19 02:21:37 +00:00
import jinja2
2023-11-08 03:48:51 +00:00
import llama_cpp . llama as llama
2023-11-08 05:07:16 +00:00
import llama_cpp . llama_types as llama_types
import llama_cpp . llama_grammar as llama_grammar
2023-11-03 06:12:14 +00:00
2024-01-19 02:21:37 +00:00
from . _utils import suppress_stdout_stderr , Singleton
2023-11-08 16:05:45 +00:00
2023-11-03 06:12:14 +00:00
class LlamaChatCompletionHandler(Protocol):
    """Structural type for any callable that serves a chat completion request.

    Any chat format can be implemented behind this interface. The one hard
    contract is the return value: a ``CreateChatCompletionResponse`` when
    ``stream`` is False, and an iterator of
    ``CreateChatCompletionStreamResponse`` chunks when ``stream`` is True.

    All parameters are keyword-only. They fall into three groups: the model
    instance, OpenAI-compatible request fields, and llama.cpp sampling
    options. Anything else arrives through ``**kwargs``.
    """

    def __call__(
        self,
        *,
        # The model instance used to run the completion.
        llama: llama.Llama,
        # OpenAI-compatible request fields.
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        model: Optional[str] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        # llama.cpp-specific sampling options.
        min_p: float = 0.05,
        typical_p: float = 1.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        ...
2024-01-19 02:21:37 +00:00
class LlamaChatCompletionHandlerNotFoundException(Exception):
    """Raised when no chat completion handler is registered under a requested name."""
class LlamaChatCompletionHandlerRegistry(Singleton):
    """Name -> chat completion handler lookup table.

    The table is a class-level dict, so every instance shares the same
    entries.
    """

    _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}

    def register_chat_completion_handler(
        self,
        name: str,
        chat_handler: LlamaChatCompletionHandler,
        overwrite: bool = False,
    ):
        """Store ``chat_handler`` under ``name``.

        Raises:
            ValueError: ``name`` is already taken and ``overwrite`` is False.
        """
        taken = name in self._chat_handlers
        if taken and not overwrite:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_handlers[name] = chat_handler

    def unregister_chat_handler(self, name: str):
        """Drop the handler stored under ``name``.

        Raises:
            ValueError: nothing is registered under ``name``.
        """
        if name not in self._chat_handlers:
            raise ValueError(f"No formatter registered under the name '{name}'.")
        self._chat_handlers.pop(name)

    def get_chat_completion_handler_by_name(
        self, name: str
    ) -> LlamaChatCompletionHandler:
        """Return the handler stored under ``name``.

        Raises:
            LlamaChatCompletionHandlerNotFoundException: ``name`` is unknown;
                the message lists the registered names.
        """
        try:
            return self._chat_handlers[name]
        except KeyError:
            known = list(self._chat_handlers.keys())
            raise LlamaChatCompletionHandlerNotFoundException(
                f"Invalid chat handler: {name} (valid formats: {known})"
            )
2023-11-03 06:12:14 +00:00
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
    """Look up the chat completion handler registered under ``name``.

    Raises:
        LlamaChatCompletionHandlerNotFoundException: if ``name`` is unknown.
    """
    registry = LlamaChatCompletionHandlerRegistry()
    return registry.get_chat_completion_handler_by_name(name)
2023-11-03 06:12:14 +00:00
def register_chat_completion_handler(name: str):
    """Create a decorator that registers a chat completion handler under ``name``.

    The decorated callable is stored in the handler registry and returned
    unchanged. Registration raises ``ValueError`` if ``name`` is already taken,
    because the registry is called without ``overwrite``.
    """

    def _register(handler: LlamaChatCompletionHandler):
        registry = LlamaChatCompletionHandlerRegistry()
        registry.register_chat_completion_handler(name, handler)
        return handler

    return _register
2023-09-29 23:52:04 +00:00
2024-01-19 02:21:37 +00:00
### Chat Formatter ###
@dataclasses.dataclass
class ChatFormatterResponse:
    """Completion parameters produced by applying a chat format to a request.

    Attributes are documented below:
    """

    # Prompt text rendered from the chat format and the request messages.
    prompt: str
    # Optional stop sequence (or list of sequences) the chat format wants the
    # completion to end on; None means the format adds no stops of its own.
    stop: Optional[Union[str, List[str]]] = None
class ChatFormatter(Protocol):
    """Structural type for a chat formatter.

    A chat formatter is a callable that turns a list of chat messages into a
    ``ChatFormatterResponse``. That response holds the prompt to complete and,
    optionally, stop sequence(s) to use for the completion.
    """

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        ...
2024-01-19 20:04:42 +00:00
class Jinja2ChatFormatter(ChatFormatter):
    """Chat formatter that renders the prompt with a Jinja2 chat template."""

    def __init__(
        self,
        template: str,
        eos_token: str,
        bos_token: str,
        add_generation_prompt: bool = True,
    ):
        """A chat formatter that uses jinja2 templates to format the prompt.

        Args:
            template: Jinja2 chat template source, e.g. a Hugging Face
                ``chat_template``.
            eos_token: exposed to the template as ``eos_token``; also returned
                as the stop sequence.
            bos_token: exposed to the template as ``bos_token``.
            add_generation_prompt: when True, an empty assistant message is
                appended before rendering to cue the model's reply.
        """
        # Local import: ``jinja2.sandbox`` is a submodule that a bare
        # ``import jinja2`` is not guaranteed to load.
        from jinja2.sandbox import ImmutableSandboxedEnvironment

        self.template = template
        self.eos_token = eos_token
        self.bos_token = bos_token
        self.add_generation_prompt = add_generation_prompt

        # Chat templates are shipped inside model files, so they are untrusted
        # input. An unsandboxed Environment lets a template reach arbitrary
        # Python attributes (server-side template injection). The immutable
        # sandbox also blocks in-place mutation of the objects passed in.
        self._environment = ImmutableSandboxedEnvironment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(self.template)

    @staticmethod
    def _raise_exception(message: str) -> None:
        """Back the ``raise_exception(...)`` helper that HF chat templates call.

        Raises:
            ValueError: always, carrying the template's message.
        """
        raise ValueError(message)

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """Render ``messages`` into a prompt; ``eos_token`` is the stop sequence.

        Raises:
            ValueError: if the template calls ``raise_exception`` (for example
                to reject an unsupported message ordering).
        """
        if self.add_generation_prompt:
            # An empty trailing assistant message makes the template open the
            # assistant's turn at the end of the prompt.
            messages = [
                *messages,
                llama_types.ChatCompletionRequestAssistantMessage(
                    role="assistant", content=""
                ),
            ]

        prompt = self._environment.render(
            messages=messages,
            eos_token=self.eos_token,
            bos_token=self.bos_token,
            raise_exception=self._raise_exception,
        )
        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])

    def to_chat_handler(self) -> LlamaChatCompletionHandler:
        """Wrap this formatter as a chat completion handler."""
        return chat_formatter_to_chat_completion_handler(self)
2024-01-04 23:12:02 +00:00
2023-10-01 01:01:34 +00:00
2023-11-03 06:12:14 +00:00
def _convert_text_completion_to_chat (
completion : llama_types . Completion ,
) - > llama_types . ChatCompletion :
2024-01-19 02:21:37 +00:00
assert " usage " in completion
2023-11-03 06:12:14 +00:00
return {
" id " : " chat " + completion [ " id " ] ,
" object " : " chat.completion " ,
" created " : completion [ " created " ] ,
" model " : completion [ " model " ] ,
" choices " : [
{
" index " : 0 ,
" message " : {
" role " : " assistant " ,
" content " : completion [ " choices " ] [ 0 ] [ " text " ] ,
} ,
" finish_reason " : completion [ " choices " ] [ 0 ] [ " finish_reason " ] ,
}
] ,
" usage " : completion [ " usage " ] ,
}
def _convert_text_completion_chunks_to_chat (
2023-11-08 03:48:51 +00:00
chunks : Iterator [ llama_types . CreateCompletionStreamResponse ] ,
2023-11-03 06:12:14 +00:00
) - > Iterator [ llama_types . ChatCompletionChunk ] :
for i , chunk in enumerate ( chunks ) :
if i == 0 :
yield {
" id " : " chat " + chunk [ " id " ] ,
" model " : chunk [ " model " ] ,
" created " : chunk [ " created " ] ,
" object " : " chat.completion.chunk " ,
" choices " : [
{
" index " : 0 ,
" delta " : {
" role " : " assistant " ,
} ,
" finish_reason " : None ,
}
] ,
}
yield {
" id " : " chat " + chunk [ " id " ] ,
" model " : chunk [ " model " ] ,
" created " : chunk [ " created " ] ,
" object " : " chat.completion.chunk " ,
" choices " : [
{
" index " : 0 ,
" delta " : {
" content " : chunk [ " choices " ] [ 0 ] [ " text " ] ,
}
if chunk [ " choices " ] [ 0 ] [ " finish_reason " ] is None
else { } ,
" finish_reason " : chunk [ " choices " ] [ 0 ] [ " finish_reason " ] ,
}
] ,
}
def _convert_completion_to_chat(
    completion_or_chunks: Union[
        llama_types.CreateCompletionResponse,
        Iterator[llama_types.CreateCompletionStreamResponse],
    ],
    stream: bool = False,
) -> Union[
    llama_types.CreateChatCompletionResponse, Iterator[llama_types.ChatCompletionChunk]
]:
    """Convert a text completion (or its chunk stream) to the chat completion shape.

    ``stream`` tells which input was given: when True, ``completion_or_chunks``
    is treated as an iterator of chunks and a lazy iterator of chat chunks is
    returned. Otherwise it is treated as a single completion dict.
    """
    if not stream:
        completion: llama_types.Completion = completion_or_chunks  # type: ignore
        return _convert_text_completion_to_chat(completion)
    chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore
    return _convert_text_completion_chunks_to_chat(chunks)
2024-01-19 02:21:37 +00:00
def chat_formatter_to_chat_completion_handler(
    chat_formatter: ChatFormatter,
) -> LlamaChatCompletionHandler:
    """Adapt a ``ChatFormatter`` into a ``LlamaChatCompletionHandler``.

    The returned handler works in four steps:

    1. It renders ``messages``, ``functions`` and ``function_call`` into a
       prompt with ``chat_formatter``.
    2. It appends the formatter's stop sequence(s), if any, to the caller's
       ``stop``.
    3. It swaps in the JSON grammar when ``response_format`` asks for a JSON
       object.
    4. It runs ``llama.create_completion`` and converts the result to the
       chat completion shape (chunk iterator when ``stream`` is True).

    NOTE: ``tools`` and ``tool_choice`` are accepted but not forwarded to the
    formatter.
    """

    def chat_completion_handler(
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        formatted = chat_formatter(
            messages=messages,
            functions=functions,
            function_call=function_call,
        )

        if formatted.stop is not None:
            # Normalise both sides to lists and concatenate into a new list,
            # so neither the caller's list nor the default is mutated.
            if stop is None:
                caller_stops = []
            elif isinstance(stop, str):
                caller_stops = [stop]
            else:
                caller_stops = stop
            if isinstance(formatted.stop, list):
                format_stops = formatted.stop
            else:
                format_stops = [formatted.stop]
            stop = caller_stops + format_stops

        wants_json = (
            response_format is not None and response_format["type"] == "json_object"
        )
        if wants_json:
            # JSON mode overrides any grammar passed by the caller.
            grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)

        raw_result = llama.create_completion(
            prompt=formatted.prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            stream=stream,
            stop=stop,
            seed=seed,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
            logit_bias=logit_bias,
        )
        return _convert_completion_to_chat(raw_result, stream=stream)

    return chat_completion_handler
2023-09-29 23:52:04 +00:00
2023-11-08 03:48:51 +00:00
def hf_autotokenizer_to_chat_formatter(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> ChatFormatter:
    """Build a chat formatter backed by a Hugging Face ``AutoTokenizer``.

    The tokenizer is loaded once, with ``transformers.AutoTokenizer.from_pretrained``.
    The returned formatter renders messages with the tokenizer's chat template
    and uses the tokenizer's ``eos_token`` as the stop sequence.

    Args:
        pretrained_model_name_or_path: model id or local path accepted by
            ``AutoTokenizer.from_pretrained``.
    """
    # https://huggingface.co/docs/transformers/main/chat_templating
    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
    from transformers import AutoTokenizer  # type: ignore

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # type: ignore
    # Configure once at construction rather than on every formatting call.
    tokenizer.use_default_system_prompt = False  # type: ignore

    def format_autotokenizer(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # add_generation_prompt=True makes templates that support it end the
        # prompt by opening the assistant turn. Without it, the model is left
        # to continue the user's message.
        prompt: str = tokenizer.apply_chat_template(  # type: ignore
            messages, tokenize=False, add_generation_prompt=True
        )
        assert isinstance(prompt, str)
        # Return formatted prompt and eos token by default
        return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)

    return format_autotokenizer
2024-01-19 02:21:37 +00:00
def hf_autotokenizer_to_chat_completion_handler(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> LlamaChatCompletionHandler:
    """Build a chat completion handler driven by a Hugging Face tokenizer's chat template."""
    return chat_formatter_to_chat_completion_handler(
        hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
    )
2024-01-19 20:04:42 +00:00
def hf_tokenizer_config_to_chat_formatter(
    tokenizer_config: Dict[str, Any]
) -> ChatFormatter:
    """Build a chat formatter from a Hugging Face ``tokenizer_config.json`` dict.

    The config must provide ``chat_template`` (a string) plus ``bos_token``
    and ``eos_token``. Each token may be a plain string, or an AddedToken-style
    dict whose ``"content"`` is a string. The formatter renders the template
    with the messages followed by an empty assistant message, and uses
    ``eos_token`` as the stop sequence.

    Raises:
        AssertionError: if the config is not a dict, or a required entry is
            missing or has the wrong type.
    """
    # Local import: ``jinja2.sandbox`` is a submodule that a bare
    # ``import jinja2`` is not guaranteed to load.
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    assert isinstance(tokenizer_config, dict)

    assert "chat_template" in tokenizer_config
    assert isinstance(tokenizer_config["chat_template"], str)
    chat_template = tokenizer_config["chat_template"]

    def _special_token(key: str) -> str:
        """Read a special token that may be a string or an AddedToken-style dict."""
        assert key in tokenizer_config
        token = tokenizer_config[key]
        # tokenizer_config.json files may serialise special tokens as dicts
        # such as {"__type": "AddedToken", "content": "<s>", ...}.
        if isinstance(token, dict):
            token = token.get("content")
        assert isinstance(token, str)
        return token

    bos_token = _special_token("bos_token")
    eos_token = _special_token("eos_token")

    def _raise_exception(message: str) -> None:
        # Backs the ``raise_exception(...)`` helper that HF chat templates call
        # to reject unsupported conversations.
        raise ValueError(message)

    # The template comes from a downloaded model config, i.e. untrusted input.
    # Compile it in Jinja's immutable sandbox so it cannot reach arbitrary
    # Python attributes (server-side template injection).
    env = ImmutableSandboxedEnvironment(
        loader=jinja2.BaseLoader(),
        trim_blocks=True,
        lstrip_blocks=True,
    ).from_string(chat_template)

    def format_autotokenizer(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # Add a blank assistant message to the end of the messages to prompt
        # the model to generate a response.
        # NOTE(review): templates that instead rely on ``add_generation_prompt``
        # may render this as an empty, already-closed assistant turn — verify
        # per template.
        prompt = env.render(
            messages=[
                *messages,
                llama_types.ChatCompletionRequestAssistantMessage(
                    role="assistant", content=""
                ),
            ],
            bos_token=bos_token,
            eos_token=eos_token,
            raise_exception=_raise_exception,
        )
        return ChatFormatterResponse(prompt=prompt, stop=eos_token)

    return format_autotokenizer
def hf_tokenizer_config_to_chat_completion_handler(
    tokenizer_config: Dict[str, Any],
) -> LlamaChatCompletionHandler:
    """Build a chat completion handler from a Hugging Face tokenizer config dict."""
    return chat_formatter_to_chat_completion_handler(
        hf_tokenizer_config_to_chat_formatter(tokenizer_config)
    )
2024-01-19 20:04:42 +00:00
### Utility functions for formatting chat prompts ###
def _get_system_message (
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
) - > str :
""" Get the first system message. """
for message in messages :
if message [ " role " ] == " system " :
return message [ " content " ] or " "
return " "
def _map_roles (
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
role_map : Dict [ str , str ] ,
) - > List [ Tuple [ str , Optional [ str ] ] ] :
""" Map the message roles. """
output : List [ Tuple [ str , Optional [ str ] ] ] = [ ]
for message in messages :
role = message [ " role " ]
if role in role_map :
content : str | None = (
message [ " content " ] if isinstance ( message [ " content " ] , str ) else None
)
output . append ( ( role_map [ role ] , content ) )
return output
def _format_llama2 (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str , sep2 : str
) - > str :
""" Format the prompt with the llama2 style. """
seps = [ sep , sep2 ]
ret = system_message + sep
for i , ( role , message ) in enumerate ( messages ) :
if system_message and i == 0 :
m = message or " "
ret + = m + seps [ i % 2 ]
elif message :
ret + = role + message + " " + seps [ i % 2 ]
else :
ret + = role + " "
return ret
def _format_add_colon_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the add-colon-single style. """
ret = system_message + sep
for role , message in messages :
if message :
ret + = role + " : " + message + sep
else :
ret + = role + " : "
return ret
def _format_add_colon_two (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str , sep2 : str
) - > str :
""" Format the prompt with the add-colon-two style. """
seps = [ sep , sep2 ]
ret = system_message + seps [ 0 ]
for i , ( role , message ) in enumerate ( messages ) :
if message :
ret + = role + " : " + message + seps [ i % 2 ]
else :
ret + = role + " : "
return ret
def _format_no_colon_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the no-colon-single style. """
ret = system_message
for role , message in messages :
if message :
ret + = role + message + sep
else :
ret + = role
return ret
def _format_add_colon_space_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the add-colon-space-single style. """
ret = system_message + sep
for role , message in messages :
if message :
ret + = role + " : " + message + sep
else :
ret + = role + " : " # must be end with a space
return ret
def _format_chatml (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the chatml style. """
ret = " " if system_message == " " else system_message + sep + " \n "
for role , message in messages :
if message :
ret + = role + " \n " + message + sep + " \n "
else :
ret + = role + " \n "
return ret
def _format_chatglm3 (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the chatglm3 style. """
ret = " "
if system_message :
ret + = system_message
for role , message in messages :
if message :
ret + = role + " \n " + " " + message
else :
ret + = role
return ret
### Chat Formats ###
def register_chat_format(name: str):
    """Return a decorator that registers a ``ChatFormatter`` under ``name``.

    The formatter is wrapped with ``chat_formatter_to_chat_completion_handler``
    and the wrapper is added to the handler registry. Registration raises
    ``ValueError`` if ``name`` is already taken. The formatter itself is
    returned unchanged, so it stays directly callable.
    """

    def _register(formatter: ChatFormatter):
        handler = chat_formatter_to_chat_completion_handler(formatter)
        registry = LlamaChatCompletionHandlerRegistry()
        registry.register_chat_completion_handler(name, handler)
        return formatter

    return _register
2023-11-05 22:00:13 +00:00
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
2023-09-29 23:52:04 +00:00
@register_chat_format("llama-2")
def format_llama2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages with the Llama-2 chat layout.

    A non-empty system message is wrapped in ``<s>[INST] <<SYS>> ... <</SYS>>``.
    User turns are tagged ``<s>[INST]`` and assistant turns ``[/INST]``. The
    prompt ends with ``[/INST]`` so the model continues as the assistant.
    """
    role_tags = {"user": "<s>[INST]", "assistant": "[/INST]"}
    turns = _map_roles(messages, role_tags)
    system_text = _get_system_message(messages)
    if system_text:
        system_text = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>".format(
            system_message=system_text
        )
    body = _format_llama2(system_text, turns, " ", "</s>")
    return ChatFormatterResponse(prompt=body + "[/INST]")
@register_chat_format("alpaca")
def format_alpaca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Alpaca-style instruction models.

    Layout (single user turn): ``<system>\\n\\n### Instruction: <msg>\\n\\n### Response:``.
    Completed turns are closed by ``\\n\\n`` and ``</s>`` alternately.
    """
    _roles = dict(user="### Instruction", assistant="### Response")
    _sep = "\n\n"
    _sep2 = "</s>"
    system_message = _get_system_message(messages)
    _messages = _map_roles(messages, _roles)
    # Open an empty assistant turn so the prompt ends with "### Response:".
    # Without it, the model had to emit the response header itself. Every
    # other format in this module appends this cue.
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
    return ChatFormatterResponse(prompt=_prompt)
2024-01-19 02:21:37 +00:00
2023-12-14 02:43:43 +00:00
@register_chat_format("qwen")
def format_qwen(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Qwen chat models (ChatML layout).

    The system turn uses the caller's system message when present, and
    otherwise ``"You are a helpful assistant."``. Turns end with
    ``<|im_end|>``; the stop sequence is ``<|endoftext|>``.
    """
    _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
    # Honour a caller-supplied system message instead of always replacing it.
    system_message = _get_system_message(messages) or "You are a helpful assistant."
    system_template = "<|im_start|>system\n{system_message}"
    system_message = system_template.format(system_message=system_message)
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _sep = "<|im_end|>"
    _prompt = _format_chatml(system_message, _messages, _sep)
    _sep2 = "<|endoftext|>"
    return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
2023-09-29 23:52:04 +00:00
@register_chat_format("vicuna")
def format_vicuna(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Vicuna v1.1-style models.

    Layout: ``<system> USER: <msg> ASSISTANT:``. Completed turns are closed by
    ``" "`` and ``</s>`` alternately. The caller's system message is used when
    present; otherwise a default Vicuna system prompt is used.
    """
    _system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
    _roles = dict(user="USER", assistant="ASSISTANT")
    _sep = " "
    _sep2 = "</s>"
    # Honour a caller-supplied system message; fall back to the default.
    system_message = _get_system_message(messages) or _system_message
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
    return ChatFormatterResponse(prompt=_prompt)


# Backward-compatible alias for callers that imported the formatter as
# ``format``. Note that this name still shadows the ``format`` builtin at
# module level.
format = format_vicuna
@register_chat_format("oasst_llama")
def format_oasst_llama(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for OpenAssistant Llama fine-tunes.

    The system message is wrapped in an ``[INST] <<SYS>> ... <</SYS>>`` header.
    Turns are tagged ``<|prompter|>`` / ``<|assistant|>`` with no colon, and
    each completed turn ends with ``</s>``.
    """
    tags = {"user": "<|prompter|>", "assistant": "<|assistant|>"}
    header = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(
        prompt=_format_no_colon_single(header, turns, "</s>")
    )
2023-11-22 11:08:06 +00:00
@register_chat_format("baichuan-2")
def format_baichuan2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Baichuan 2 chat models.

    The system message is emitted verbatim, followed by turns tagged
    ``<reserved_106>`` (user) / ``<reserved_107>`` (assistant) with no
    separators.
    """
    tags = {"user": "<reserved_106>", "assistant": "<reserved_107>"}
    header = "{system_message}".format(system_message=_get_system_message(messages))
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_no_colon_single(header, turns, ""))
2023-11-23 06:19:50 +00:00
@register_chat_format("baichuan")
def format_baichuan(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Baichuan (v1) chat models.

    The system message is emitted verbatim, followed by turns tagged
    ``<reserved_102>`` (user) / ``<reserved_103>`` (assistant) with no
    separators.
    """
    tags = {"user": "<reserved_102>", "assistant": "<reserved_103>"}
    header = "{system_message}".format(system_message=_get_system_message(messages))
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_no_colon_single(header, turns, ""))
2023-09-29 23:52:04 +00:00
@register_chat_format("openbuddy")
def format_openbuddy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for OpenBuddy models.

    A fixed persona prompt, ending with a scripted greeting exchange, is
    always used as the system text; system messages in ``messages`` are not
    mapped. Turns are rendered as ``User: ...`` / ``Assistant: ...`` separated
    by newlines, ending with a bare ``Assistant:`` cue.
    """
    persona = """Consider a conversation between User (a human) and Assistant (named Buddy).
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
Buddy cannot access the Internet.
Buddy can fluently speak the user's language (e.g. English, Chinese).
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
Buddy possesses vast knowledge about the world, history, and culture.
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.

User: Hi.
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?"""
    tags = {"user": "User", "assistant": "Assistant"}
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_add_colon_single(persona, turns, "\n"))
@register_chat_format("redpajama-incite")
def format_redpajama_incite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for RedPajama-INCITE chat models.

    Turns are ``<human>: ...`` / ``<bot>: ...`` separated by newlines.
    Generation stops when the model starts a new ``<human>`` turn.
    """
    tags = {"user": "<human>", "assistant": "<bot>"}
    header = _get_system_message(messages)
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(
        prompt=_format_add_colon_single(header, turns, "\n"),
        stop="<human>",
    )
@register_chat_format("snoozy")
def format_snoozy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for GPT4All Snoozy.

    Layout: ``### Instruction:\\n<system>\\n### Prompt: <msg>\\n### Response:``.
    A default instruction is used when no system message is supplied.
    Generation stops on ``###``.
    """
    system_template = "### Instruction:\n{system_message}"
    default_system_message = "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response."
    _system_message = _get_system_message(messages)
    _system_message = (
        _system_message if _system_message != "" else default_system_message
    )
    # The formatted header is the system text the model should see. It was
    # previously overwritten with the bare message right after being built,
    # which dropped the "### Instruction:" heading from every prompt.
    system_message = system_template.format(system_message=_system_message)
    _roles = dict(user="### Prompt", assistant="### Response")
    _sep = "\n"
    _stop = "###"
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_single(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt, stop=_stop)
@register_chat_format("phind")
def format_phind(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Phind CodeLlama models.

    A fixed ``### System Prompt`` header is always used. Turns are rendered as
    ``### User Message: ...`` / ``### Assistant: ...`` separated by blank lines.
    """
    tags = {"user": "### User Message", "assistant": "### Assistant"}
    header = "### System Prompt\nYou are an intelligent programming assistant."
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_add_colon_single(header, turns, "\n\n"))
2023-11-21 09:02:20 +00:00
2023-11-21 05:19:25 +00:00
@register_chat_format("intel")
def format_intel(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Intel neural-chat models.

    Layout: ``### System:\\n<system>\\n### User: <msg>\\n### Assistant:``.
    """
    # No trailing colon on the role names: _format_add_colon_single appends
    # ": " itself. "### User:" would render as "### User:: ...".
    _roles = dict(user="### User", assistant="### Assistant")
    _sep = "\n"
    _system_template = "### System:\n{system_message}"
    # Substitute the caller's system message. Passing the template through
    # unformatted would leak the literal "{system_message}" into the prompt.
    _system_message = _system_template.format(
        system_message=_get_system_message(messages)
    )
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_single(_system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt)
2023-09-29 23:52:04 +00:00
@register_chat_format("open-orca")
def format_open_orca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for OpenOrca models.

    A fixed system prompt is followed by ``User: ...`` / ``Assistant: ...``
    turns, each terminated by ``<|end_of_turn|>\\n``, ending with an
    ``Assistant: `` cue. Generation stops on ``User``.
    """
    system_template = "{system_message}"
    system_message = (
        "You are a helpful assistant. Please answer truthfully and write out your "
        "thinking step by step to be sure you get the right answer. If you make a mistake or encounter "
        "an error in your thinking, say so out loud and attempt to correct it. If you don't know or "
        "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, "
        "and physicist. You will also act as the most appropriate type of expert to answer any particular "
        "question or solve the relevant problem; state which expert type your are, if so. Also think of "
        "any particular named expert that would be ideal to answer the relevant question or solve the "
        "relevant problem; name and act as them, if appropriate."
    )
    # Keys must be the lowercase role names found in message["role"]. Mapping
    # "User"/"Assistant" onto themselves matched no message, so every
    # user/assistant turn was silently dropped from the prompt.
    _roles = dict(user="User", assistant="Assistant")
    sep = "<|end_of_turn|>\n"
    # NOTE: the reference template also lists stop_token_ids=[32000, 32001]
    # ("<|end_of_turn|>"); only a stop string is applied here.
    stop_str = "User"
    system_message = system_template.format(system_message=system_message)
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_space_single(system_message, _messages, sep)
    return ChatFormatterResponse(prompt=_prompt, stop=stop_str)
2023-10-01 01:01:34 +00:00
2023-11-21 05:19:25 +00:00
@register_chat_format("mistrallite")
def format_mistrallite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for MistralLite.

    The system message is wrapped as ``<|system|>...</s>``. User turns are
    tagged ``<|prompter|>``, and the assistant tag ``</s>\\n<|assistant|>`` also
    closes the preceding user turn.
    """
    tags = {"user": "<|prompter|>", "assistant": "</s>\n<|assistant|>"}
    header = "<|system|>{system_message}</s>".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_no_colon_single(header, turns, " "))
2024-01-19 02:21:37 +00:00
2023-11-23 06:20:08 +00:00
@register_chat_format("zephyr")
def format_zephyr(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Zephyr models.

    Uses the ChatML-style layout with ``<|system|>`` / ``<|user|>`` /
    ``<|assistant|>`` tags. ``</s>`` terminates each turn and is also the
    stop sequence.
    """
    end_of_turn = "</s>"
    header = "<|system|>\n{system_message}".format(
        system_message=_get_system_message(messages)
    )
    tags = {"user": "<|user|>\n", "assistant": "<|assistant|>\n"}
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(
        prompt=_format_chatml(header, turns, end_of_turn), stop=end_of_turn
    )
2023-11-21 09:02:20 +00:00
2023-12-12 01:44:04 +00:00
@register_chat_format("pygmalion")
def format_pygmalion(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format messages for Pygmalion models.

    ChatML-style layout with ``<|system|>`` / ``<|user|>`` / ``<|model|>`` tags
    and a newline turn terminator. The newline is also the stop sequence, so
    generation ends at the first line break.
    """
    newline = "\n"
    header = "<|system|>{system_message}".format(
        system_message=_get_system_message(messages)
    )
    tags = {"user": "<|user|>", "assistant": "<|model|>"}
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(
        prompt=_format_chatml(header, turns, newline), stop=newline
    )
2023-10-01 01:01:34 +00:00
@register_chat_format("chatml")
def format_chatml(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render ``messages`` in the ChatML prompt layout.

    Each turn opens with ``<|im_start|>`` plus the role name (``system``,
    ``user`` or ``assistant``); ``_format_chatml`` joins turns with
    ``<|im_end|>``, which is also the returned stop sequence. An empty
    assistant turn is appended last.
    """
    im_end = "<|im_end|>"
    tags = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
    header = f"<|im_start|>system\n{_get_system_message(messages)}"
    turns = [*_map_roles(messages, tags), (tags["assistant"], None)]
    return ChatFormatterResponse(
        prompt=_format_chatml(header, turns, im_end),
        stop=im_end,
    )
2023-11-03 06:12:14 +00:00
2024-01-19 02:21:37 +00:00
2024-01-04 23:12:02 +00:00
@register_chat_format("chatglm3")
def format_chatglm3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render ``messages`` in the ChatGLM3 prompt layout.

    The system text follows ``<|system|>`` and a newline; user and assistant
    turns are tagged ``<|user|>`` / ``<|assistant|>`` and rendered by
    ``_format_chatglm3`` with ``</s>`` as separator, which is also the
    returned stop sequence. An empty assistant turn is appended last.
    """
    separator = "</s>"
    tags = {"user": "<|user|>", "assistant": "<|assistant|>"}
    header = f"<|system|>\n{_get_system_message(messages)}"
    turns = _map_roles(messages, tags) + [(tags["assistant"], None)]
    return ChatFormatterResponse(
        prompt=_format_chatglm3(header, turns, separator),
        stop=separator,
    )
2023-11-21 09:02:20 +00:00
2023-11-21 05:19:25 +00:00
@register_chat_format("openchat")
def format_openchat(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render ``messages`` in the OpenChat ("GPT4 Correct") prompt layout.

    The system text is followed by ``<|end_of_turn|>``; user turns are tagged
    ``GPT4 Correct User: `` and assistant turns
    ``<|end_of_turn|>GPT4 Correct Assistant: ``. ``_format_chatml`` joins the
    turns with ``<|end_of_turn|>``, which is also the returned stop sequence.
    An empty assistant turn is appended last.
    """
    eot = "<|end_of_turn|>"
    tags = {
        "user": "GPT4 Correct User: ",
        "assistant": "<|end_of_turn|>GPT4 Correct Assistant: ",
    }
    header = f"{_get_system_message(messages)}<|end_of_turn|>"
    turns = _map_roles(messages, tags)
    turns.append((tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_chatml(header, turns, eot), stop=eot)
2023-11-03 06:12:14 +00:00
2024-01-04 23:12:58 +00:00
# Chat format for Saiga models, see more details and available models:
# https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
@register_chat_format("saiga")
def format_saiga(
    messages: list[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Format ``messages`` for Saiga models.

    Each turn becomes ``<s>{role}``, a newline, its content and ``</s>``; a
    turn whose content is empty or missing becomes ``<s>{role}`` plus a
    newline only. The prompt ends with ``<s>bot`` so the model writes the next
    bot turn, and the whole prompt is stripped of surrounding whitespace.
    """
    _message_template = "<s>{role}\n{content}</s>"
    # Saiga calls the model's role "bot", while OpenAI-style requests label
    # model turns "assistant". Map "assistant" to "bot" so earlier replies are
    # rendered too (NOTE(review): assumes _map_roles only keeps roles present
    # in this map, which is why they were missing before). The "bot" key is
    # kept for callers that already send that role.
    _roles = dict(user="user", assistant="bot", bot="bot", system="system")
    _messages = _map_roles(messages, _roles)

    _prompt = "".join(
        _message_template.format(role=role, content=content)
        if content
        else f"<s>{role}\n"
        for role, content in _messages
    )
    # Response template
    _prompt += "<s>bot"
    return ChatFormatterResponse(prompt=_prompt.strip())
2023-11-03 06:12:14 +00:00
@register_chat_completion_handler("functionary")
def functionary_chat_handler(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    **kwargs,  # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    """Chat completion handler for functionary (v1 prompt format) models.

    The callable functions (``functions``, or the function entries of
    ``tools`` when given) are rendered as a TypeScript-like namespace in a
    system message, followed by the conversation. ``tool_choice`` overrides
    ``function_call``. Then:

    * ``function_call == "none"``, or no functions at all: a plain assistant
      reply is generated (``stream`` honoured).
    * ``function_call`` is ``None`` / ``"auto"``: one line is generated to let
      the model pick a function; if it names none (no ``to=``), a plain reply
      is generated as above.
    * otherwise (a ``{"name": ...}`` dict, another string, or a function
      picked in auto mode): the arguments are generated under a GBNF grammar
      built from the function's JSON-schema ``parameters`` (falling back to a
      generic JSON grammar) and returned as both ``function_call`` and a
      single ``tool_calls`` entry. Streaming is not supported in this case.

    The caller's ``messages`` are never modified.
    """
    SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""

    def generate_type_definition(
        param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
    ) -> str:
        """Render one JSON-schema node as a TypeScript-like type expression."""
        indent = "  " * indent_level
        if "$ref" in param:
            # Reference to a shared definition: use the last path segment.
            ref_name = param["$ref"].split("/")[-1]
            return ref_name
        elif param.get("type") == "array":
            items = param.get("items", {})
            item_type = generate_type_definition(items, indent_level + 1, shared_defs)
            return f"Array<{item_type}>"
        elif param.get("type") == "object":
            properties = param.get("properties", {})
            nested_schema = "{\n"
            for nested_param_name, nested_param in properties.items():
                nested_param_type = generate_type_definition(
                    nested_param, indent_level + 1, shared_defs
                )
                nested_schema += (
                    f"{indent}  {nested_param_name}: {nested_param_type},\n"
                )
            nested_schema += indent + "}"
            return nested_schema
        elif "enum" in param:
            # Enum type
            return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]])
        else:
            # Simple type
            return param.get("type", "any")

    def generate_shared_definitions(shared_defs, indent_level: int) -> str:
        """Render ``$defs`` entries as ``type Name = ...;`` lines."""
        indent = "  " * indent_level
        shared_definitions = ""
        for def_name, def_properties in shared_defs.items():
            shared_definitions += f"{indent}type {def_name} = "
            if def_properties.get("type") == "object":
                shared_definitions += generate_type_definition(
                    def_properties, indent_level, shared_defs
                )
            elif "enum" in def_properties:
                # Enum type
                shared_definitions += " | ".join(
                    [f'"{enum_value}"' for enum_value in def_properties["enum"]]
                )
            shared_definitions += ";\n"
        return shared_definitions

    def generate_schema_from_functions(functions, namespace="functions") -> str:
        """Render all function signatures inside one ``namespace`` block."""
        schema = (
            "// Supported function definitions that should be called when necessary.\n"
        )
        schema += f"namespace {namespace} {{\n\n"

        # Generate shared definitions
        shared_definitions = {}
        for function in functions:
            parameters = function.get("parameters", {})
            shared_definitions.update(parameters.get("$defs", {}))

        schema += generate_shared_definitions(shared_definitions, 1)

        for function in functions:
            function_name = function["name"]
            description = function.get("description", "")
            parameters = function.get("parameters", {})
            required_params = parameters.get("required", [])

            schema += f"  // {description}\n"
            schema += f"  type {function_name} = (_: {{\n"

            for param_name, param in parameters.get("properties", {}).items():
                param_description = param.get("description", "")
                param_type = generate_type_definition(param, 2, shared_definitions)
                optional_indicator = "" if param_name in required_params else "?"
                schema += f"    // {param_description}\n"
                schema += f"    {param_name}{optional_indicator}: {param_type},\n"
            schema += "  }) => any;\n\n"

        schema += "}} // namespace {}\n".format(namespace)
        return schema

    def prepare_messages_for_inference(
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    ) -> str:
        """Build the full text prompt, ending with an open assistant turn.

        ``functions`` must already include any functions taken from ``tools``;
        the schema system message is emitted exactly once.
        """
        all_messages: List[llama_types.ChatCompletionRequestMessage] = []
        if functions is not None:
            all_messages.append(
                llama_types.ChatCompletionRequestSystemMessage(
                    role="system", content=generate_schema_from_functions(functions)
                )
            )

        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=SYSTEM_MESSAGE
            )
        )

        for message in messages:
            # Work on copies: rewriting names in place would leak the
            # "functions." prefix into the caller's dicts and stack it on
            # every subsequent call with the same history.
            message = message.copy()
            # Function call responses
            if message["role"] == "function" and "name" in message:
                message["name"] = f"functions.{message['name']}"
            # Function call requests by assistant
            if "function_call" in message:
                original_call = message["function_call"]  # type: ignore
                message["function_call"] = {  # type: ignore
                    **original_call,
                    "name": f"functions.{original_call['name']}",
                }
            all_messages.append(message)

        all_messages.append(
            llama_types.ChatCompletionRequestAssistantMessage(
                role="assistant", content=None
            )
        )

        def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
            if msg["role"] == "system":
                return f"system:\n{msg['content']}\n"
            elif msg["role"] == "function" and "name" in msg:
                return f"function name={msg['name']}:\n{msg['content']}\n"
            elif msg["role"] == "function" and "function_call" in msg:
                return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
            elif msg["role"] == "tool":
                if msg["content"] is not None:
                    return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
                else:
                    return f"function name={msg['tool_call_id']}\n"
            elif msg["role"] == "user":
                if msg["content"] is None:
                    return "user:\n</s></s>\n"
                else:
                    return f"user:\n</s>{msg['content']}</s>\n"
            elif msg["role"] == "assistant":
                if msg["content"] is not None and "function_call" in msg:
                    return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
                elif "function_call" in msg:
                    return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
                elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
                    # NOTE: only the first tool call is rendered; probably
                    # doesn't work with the functionary model
                    tool_call = msg["tool_calls"][0]
                    return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
                elif msg["content"] is None:
                    return "assistant"
                else:
                    return f"assistant:\n{msg['content']}\n"
            else:
                raise ValueError(f"Unsupported role: {msg['role']}")

        return "".join([message_to_str(msg) for msg in all_messages])

    def create_text_completion(text_prompt: str):
        """Generate an unconstrained assistant reply (streaming if requested)."""
        completion_or_completion_chunks = llama.create_completion(
            prompt=text_prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            stream=stream,
            stop=["user:", "</s>"],
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        )
        return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream)  # type: ignore

    # Tools are normalised into plain functions; from here on only
    # `functions` is consulted.
    if tools is not None:
        functions = [tool["function"] for tool in tools if tool["type"] == "function"]

    if tool_choice is not None:
        function_call = (
            tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
        )

    prompt = prepare_messages_for_inference(messages, functions)

    # "none" forbids calling a function; with no functions there is nothing
    # to call. Either way the model answers in plain text.
    if function_call == "none" or (function_call is None and not functions):
        return create_text_completion(prompt + ":\n")

    if function_call is None or function_call == "auto":
        stop = "\n"
        probe: llama_types.Completion = llama.create_completion(
            prompt=prompt, stop=stop, stream=False
        )  # type: ignore
        completion_text = probe["choices"][0]["text"]
        if "to=" not in completion_text:
            # The model chose to reply directly instead of calling a function.
            return create_text_completion(prompt + ":\n")
        # strip " to=functions." and ending ":"
        function_call = completion_text.split(".")[-1][:-1]
        new_prompt = prompt + completion_text + stop
    elif isinstance(function_call, dict):
        new_prompt = prompt + f" to=functions.{function_call['name']}:\n"
        function_call = function_call["name"]
    else:
        new_prompt = prompt + ":\n"

    function_body = next(
        (
            function.get("parameters")
            for function in functions or []
            if function["name"] == function_call
        ),
        None,
    )

    function_grammar = None
    if function_body is not None:
        try:
            with suppress_stdout_stderr(disable=llama.verbose):
                grammar_text = llama_grammar.json_schema_to_gbnf(
                    json.dumps(function_body)
                )
                function_grammar = llama_grammar.LlamaGrammar.from_string(
                    grammar_text
                )
            if llama.verbose:
                print(grammar_text)
        except Exception as e:
            if llama.verbose:
                print(
                    "Failed to parse function body as JSON schema, falling back to default grammar"
                )
                print(e)
    if function_grammar is None:
        with suppress_stdout_stderr(disable=llama.verbose):
            function_grammar = llama_grammar.LlamaGrammar.from_string(
                llama_grammar.JSON_GBNF
            )

    # TODO: support stream mode for function calls (checked before generating
    # so no work is wasted).
    assert stream is False

    completion: llama_types.Completion = llama.create_completion(
        prompt=new_prompt,
        stop=["user:", "</s>"],
        stream=False,
        grammar=function_grammar,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
    )  # type: ignore
    assert "usage" in completion
    assert isinstance(function_call, str)

    if llama.verbose:
        print(new_prompt)
        print(completion["choices"][0]["text"])

    arguments = completion["choices"][0]["text"]
    return llama_types.CreateChatCompletionResponse(
        id="chat" + completion["id"],
        object="chat.completion",
        created=completion["created"],
        model=completion["model"],
        choices=[
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": None,
                    "function_call": {
                        "name": function_call,
                        "arguments": arguments,
                    },
                    "tool_calls": [
                        {
                            "id": function_call,
                            "type": "function",
                            "function": {
                                "name": function_call,
                                "arguments": arguments,
                            },
                        }
                    ],
                },
                "finish_reason": "tool_calls",
            }
        ],
        usage=completion["usage"],
    )
2023-11-08 03:48:51 +00:00
class Llava15ChatHandler:
    """Chat completion handler for LLaVA 1.5 models.

    Text turns are tokenized and evaluated straight into the ``llama``
    context; images in user messages are embedded with a CLIP model (loaded
    through ``llama_cpp.llava_cpp``) and evaluated into the same context.
    A normal completion is then run from the accumulated tokens.
    """

    # Class-level default so ``__del__`` can always read ``self._clip_free``,
    # even on an instance whose ``__init__`` never assigned it.
    _clip_free = None

    def __init__(self, clip_model_path: str, verbose: bool = False):
        """Load the CLIP model stored at ``clip_model_path``.

        Raises:
            ValueError: if the path does not exist or the model fails to load.
        """
        # Assign everything __del__ reads before anything below can raise.
        self.clip_model_path = clip_model_path
        self.verbose = verbose
        self.clip_ctx = None

        import llama_cpp.llava_cpp as llava_cpp

        self._llava_cpp = llava_cpp
        self._clip_free = self._llava_cpp._libllava.clip_free  # type: ignore

        if not os.path.exists(clip_model_path):
            raise ValueError(f"Clip model path does not exist: {clip_model_path}")

        with suppress_stdout_stderr(disable=self.verbose):
            self.clip_ctx = self._llava_cpp.clip_model_load(
                self.clip_model_path.encode(), 0
            )

        # NOTE(review): assumes clip_model_load's ctypes restype maps a NULL
        # return to None (as c_void_p does) — confirm in llava_cpp.
        if self.clip_ctx is None:
            raise ValueError(f"Failed to load clip model: {clip_model_path}")

    def __del__(self):
        """Free the CLIP context if one was loaded; safe after a failed ``__init__``."""
        clip_ctx = getattr(self, "clip_ctx", None)
        if clip_ctx is None or self._clip_free is None:
            return
        with suppress_stdout_stderr(disable=getattr(self, "verbose", False)):
            self._clip_free(clip_ctx)
        self.clip_ctx = None

    def load_image(self, image_url: str) -> bytes:
        """Return image bytes from a base64 ``data:`` URL or by fetching ``image_url``."""
        if image_url.startswith("data:"):
            import base64

            image_bytes = base64.b64decode(image_url.split(",")[1])
            return image_bytes
        else:
            import urllib.request

            with urllib.request.urlopen(image_url) as f:
                image_bytes = f.read()
                return image_bytes

    def __call__(
        self,
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        """Evaluate the conversation (text and images) and complete the next assistant turn.

        ``functions``, ``function_call``, ``tools`` and ``tool_choice`` are
        accepted for interface compatibility but not used. A
        ``response_format`` of type ``json_object`` replaces ``grammar`` with
        the generic JSON grammar.
        """
        import array

        assert (
            llama.context_params.logits_all is True
        )  # BUG: logits_all=True is required for llava
        assert self.clip_ctx is not None
        system_prompt = _get_system_message(messages)
        system_prompt = (
            system_prompt
            if system_prompt != ""
            else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
        )
        user_role = "\nUSER:"
        assistant_role = "\nASSISTANT:"
        llama.reset()
        llama.eval(llama.tokenize(system_prompt.encode("utf8"), add_bos=True))
        for message in messages:
            if message["role"] == "user" and message["content"] is not None:
                if isinstance(message["content"], str):
                    llama.eval(
                        llama.tokenize(
                            f"{user_role} {message['content']}".encode("utf8"),
                            add_bos=False,
                        )
                    )
                else:
                    assert isinstance(message["content"], list)
                    llama.eval(
                        llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
                    )
                    for content in message["content"]:
                        if content["type"] == "text":
                            llama.eval(
                                llama.tokenize(
                                    f"{content['text']}".encode("utf8"), add_bos=False
                                )
                            )
                        if content["type"] == "image_url":
                            image_bytes = (
                                self.load_image(content["image_url"]["url"])
                                if isinstance(content["image_url"], dict)
                                else self.load_image(content["image_url"])
                            )
                            # Wrap the bytes in a ctypes array without copying.
                            data_array = array.array("B", image_bytes)
                            c_ubyte_ptr = (
                                ctypes.c_ubyte * len(data_array)
                            ).from_buffer(data_array)
                            with suppress_stdout_stderr(disable=self.verbose):
                                embed = (
                                    self._llava_cpp.llava_image_embed_make_with_bytes(
                                        ctx_clip=self.clip_ctx,
                                        n_threads=llama.context_params.n_threads,
                                        image_bytes=c_ubyte_ptr,
                                        image_bytes_length=len(image_bytes),
                                    )
                                )
                            try:
                                # llava_eval_image_embed advances n_past through
                                # the pointer; the new value is copied back below.
                                n_past = ctypes.c_int(llama.n_tokens)
                                n_past_p = ctypes.pointer(n_past)
                                with suppress_stdout_stderr(disable=self.verbose):
                                    self._llava_cpp.llava_eval_image_embed(
                                        ctx_llama=llama.ctx,
                                        embed=embed,
                                        n_batch=llama.n_batch,
                                        n_past=n_past_p,
                                    )
                                assert llama.n_ctx() >= n_past.value
                                llama.n_tokens = n_past.value
                            finally:
                                with suppress_stdout_stderr(disable=self.verbose):
                                    self._llava_cpp.llava_image_embed_free(embed)
            if message["role"] == "assistant" and message["content"] is not None:
                llama.eval(
                    llama.tokenize(
                        f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
                    )
                )
                assert llama.n_ctx() >= llama.n_tokens
        llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
        assert llama.n_ctx() >= llama.n_tokens

        prompt = llama.input_ids[: llama.n_tokens].tolist()

        if response_format is not None and response_format["type"] == "json_object":
            with suppress_stdout_stderr(disable=self.verbose):
                grammar = llama_grammar.LlamaGrammar.from_string(
                    llama_grammar.JSON_GBNF
                )

        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                min_p=min_p,
                typical_p=typical_p,
                stream=stream,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repeat_penalty=repeat_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                model=model,
                logits_processor=logits_processor,
                grammar=grammar,
            ),
            stream=stream,
        )