2023-11-03 02:12:14 -04:00
from __future__ import annotations
2023-11-06 09:07:27 -05:00
import os
2023-11-10 02:51:58 -05:00
import json
2023-11-08 04:48:51 +01:00
import ctypes
2023-09-29 19:52:04 -04:00
import dataclasses
2024-02-08 09:07:03 +08:00
import random
import string
2024-04-30 01:35:38 -04:00
from contextlib import ExitStack
2024-04-05 10:50:49 -04:00
from typing import Any , Dict , Iterator , List , Literal , Optional , Tuple , Union , Protocol , cast
2023-11-06 09:07:27 -05:00
2024-01-18 21:21:37 -05:00
import jinja2
2024-04-20 00:00:53 -04:00
import numpy as np
import numpy . typing as npt
2023-11-08 04:48:51 +01:00
import llama_cpp . llama as llama
2023-11-08 00:07:16 -05:00
import llama_cpp . llama_types as llama_types
import llama_cpp . llama_grammar as llama_grammar
2023-11-03 02:12:14 -04:00
2024-02-23 18:40:52 +09:00
from . _logger import logger
2024-01-18 21:21:37 -05:00
from . _utils import suppress_stdout_stderr , Singleton
2023-11-08 11:05:45 -05:00
2024-01-29 14:22:23 -05:00
### Common Chat Templates and Special Tokens ###
# NOTE: guess_chat_format_from_gguf_metadata compares a model's GGUF
# `tokenizer.chat_template` against the *_CHAT_TEMPLATE strings below with `==`,
# so each template must stay byte-for-byte identical to the upstream
# tokenizer_config.json value it was copied from (do not reformat them).
# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
CHATML_CHAT_TEMPLATE = " { % f or message in messages % } {{ ' <|im_start|> ' + message[ ' role ' ] + ' \n ' + message[ ' content ' ] + ' <|im_end|> ' + ' \n ' }} { % e ndfor % } { % i f add_generation_prompt % } {{ ' <|im_start|>assistant \n ' }} { % e ndif % } "
CHATML_BOS_TOKEN = " <s> "
CHATML_EOS_TOKEN = " <|im_end|> "
# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
MISTRAL_INSTRUCT_CHAT_TEMPLATE = " {{ bos_token }} { % f or message in messages % } { % i f (message[ ' role ' ] == ' user ' ) != (loop.index0 % 2 == 0) % } {{ raise_exception( ' Conversation roles must alternate user/assistant/user/assistant/... ' ) }} { % e ndif % } { % i f message[ ' role ' ] == ' user ' % } {{ ' [INST] ' + message[ ' content ' ] + ' [/INST] ' }} { % e lif message[ ' role ' ] == ' assistant ' % } {{ message[ ' content ' ] + eos_token + ' ' }} { % e lse % } {{ raise_exception( ' Only user and assistant roles are supported! ' ) }} { % e ndif % } { % e ndfor % } "
MISTRAL_INSTRUCT_BOS_TOKEN = " <s> "
MISTRAL_INSTRUCT_EOS_TOKEN = " </s> "
2024-02-23 16:27:38 +00:00
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
# Same as the Mistral template above except that assistant turns end with
# eos_token alone (no trailing ' '); both map to the "mistral-instruct" format.
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = " {{ bos_token }} { % f or message in messages % } { % i f (message[ ' role ' ] == ' user ' ) != (loop.index0 % 2 == 0) % } {{ raise_exception( ' Conversation roles must alternate user/assistant/user/assistant/... ' ) }} { % e ndif % } { % i f message[ ' role ' ] == ' user ' % } {{ ' [INST] ' + message[ ' content ' ] + ' [/INST] ' }} { % e lif message[ ' role ' ] == ' assistant ' % } {{ message[ ' content ' ] + eos_token}} { % e lse % } {{ raise_exception( ' Only user and assistant roles are supported! ' ) }} { % e ndif % } { % e ndfor % } "
2024-01-29 14:22:23 -05:00
2024-04-23 06:33:29 +00:00
# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
LLAMA3_INSTRUCT_CHAT_TEMPLATE = " { % s et loop_messages = messages % } { % f or message in loop_messages % } { % s et content = ' <|start_header_id|> ' + message[ ' role ' ] + ' <|end_header_id|> \n \n ' + message[ ' content ' ] | trim + ' <|eot_id|> ' % } { % i f loop.index0 == 0 % } { % s et content = bos_token + content % } { % e ndif % } {{ content }} { % e ndfor % } { % i f add_generation_prompt % } {{ ' <|start_header_id|>assistant<|end_header_id|> \n \n ' }} { % e ndif % } "
2024-01-29 14:22:23 -05:00
### Chat Completion Handler ###
2023-11-03 02:12:14 -04:00
2024-02-12 15:56:07 -05:00
2023-11-03 02:12:14 -04:00
class LlamaChatCompletionHandler(Protocol):
    """Structural type for any callable that serves a chat completion request.

    The protocol is deliberately generic so that any chat format can be
    implemented behind it. The single hard requirement concerns the return
    value: a complete ``CreateChatCompletionResponse`` when ``stream`` is
    False, and an iterator of ``CreateChatCompletionStreamResponse`` chunks
    when ``stream`` is True. Every parameter is keyword-only.
    """

    def __call__(
        self,
        *,
        # The llama.cpp model instance that runs the completion.
        llama: llama.Llama,
        # --- OpenAI-compatible request parameters ---
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        model: Optional[str] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        # --- llama.cpp-specific sampling parameters ---
        min_p: float = 0.05,
        typical_p: float = 1.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]: ...
2023-11-03 02:12:14 -04:00
2024-01-18 21:21:37 -05:00
class LlamaChatCompletionHandlerNotFoundException(Exception):
    """Raised when no chat completion handler is registered under a requested name."""
class LlamaChatCompletionHandlerRegistry(Singleton):
    """Name -> chat completion handler mapping.

    Handlers live in a class-level dict, so every instance of this registry
    sees the same mapping.
    """

    _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}

    def register_chat_completion_handler(
        self,
        name: str,
        chat_handler: LlamaChatCompletionHandler,
        overwrite: bool = False,
    ):
        """Register ``chat_handler`` under ``name``.

        Raises:
            ValueError: if ``name`` is already registered and ``overwrite`` is False.
        """
        if not overwrite and name in self._chat_handlers:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_handlers[name] = chat_handler

    def unregister_chat_handler(self, name: str):
        """Remove the handler registered under ``name``.

        Raises:
            ValueError: if nothing is registered under ``name``.
        """
        try:
            del self._chat_handlers[name]
        except KeyError:
            # Suppress the internal KeyError context: the ValueError is the whole story.
            raise ValueError(f"No formatter registered under the name '{name}'.") from None

    def get_chat_completion_handler_by_name(
        self, name: str
    ) -> LlamaChatCompletionHandler:
        """Return the handler registered under ``name``.

        Raises:
            LlamaChatCompletionHandlerNotFoundException: if ``name`` is unknown;
                the message lists the registered names.
        """
        try:
            return self._chat_handlers[name]
        except KeyError:
            # `from None` avoids a misleading "During handling of the above
            # exception, another exception occurred" traceback.
            raise LlamaChatCompletionHandlerNotFoundException(
                f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
            ) from None
2023-11-03 02:12:14 -04:00
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
    """Look up the chat completion handler registered under ``name``.

    Raises:
        LlamaChatCompletionHandlerNotFoundException: if ``name`` is not registered.
    """
    registry = LlamaChatCompletionHandlerRegistry()
    return registry.get_chat_completion_handler_by_name(name)
2023-11-03 02:12:14 -04:00
def register_chat_completion_handler(name: str):
    """Decorator factory that registers the decorated handler under ``name``.

    The handler itself is returned unchanged. Registration does not pass
    ``overwrite``, so it raises ``ValueError`` if ``name`` is already taken.
    """

    def _register(handler: LlamaChatCompletionHandler):
        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
            name, handler
        )
        return handler

    return _register
2023-09-29 19:52:04 -04:00
2024-01-18 21:21:37 -05:00
### Chat Formatter ###
2024-02-12 15:56:07 -05:00
2024-01-18 21:21:37 -05:00
@dataclasses.dataclass
class ChatFormatterResponse:
    """Completion parameters produced by a chat format for one request.

    Attributes:
        prompt: The formatted prompt built from the chat format and messages.
        stop: Optional stop string, or list of stop strings, for this format.
        stopping_criteria: Optional ``llama.StoppingCriteriaList`` to forward
            to the completion call.
    """

    prompt: str
    stop: Optional[Union[str, List[str]]] = None
    stopping_criteria: Optional[llama.StoppingCriteriaList] = None
2024-01-18 21:21:37 -05:00
class ChatFormatter(Protocol):
    """Structural type for chat formatters.

    A chat formatter is a callable taking the conversation as the keyword-only
    ``messages`` argument (plus arbitrary extra keyword arguments). It returns
    a ``ChatFormatterResponse`` holding the prompt to complete and, optionally,
    the stop string(s) to use for that completion.
    """

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse: ...
2024-01-18 21:21:37 -05:00
2024-01-19 15:04:42 -05:00
class Jinja2ChatFormatter(ChatFormatter):
    """Chat formatter that renders the prompt from a jinja2 chat template."""

    def __init__(
        self,
        template: str,
        eos_token: str,
        bos_token: str,
        add_generation_prompt: bool = True,
        stop_token_ids: Optional[List[int]] = None,
    ):
        """A chat formatter that uses jinja2 templates to format the prompt.

        Args:
            template: jinja2 source of the chat template.
            eos_token: Exposed to the template as ``eos_token``; also used as
                the stop string of every response.
            bos_token: Exposed to the template as ``bos_token``.
            add_generation_prompt: Passed to the template as
                ``add_generation_prompt``.
            stop_token_ids: If given, generation stops as soon as the most
                recently sampled token id is one of these ids.
        """
        # Chat templates may come from model metadata or config files the user
        # did not write (this module reads GGUF `tokenizer.chat_template`, for
        # example). A plain jinja2.Environment lets such a template reach Python
        # internals through attribute access (server-side template injection ->
        # arbitrary code execution, cf. CVE-2024-34359), so render inside
        # jinja2's immutable sandbox instead.
        from jinja2.sandbox import ImmutableSandboxedEnvironment

        self.template = template
        self.eos_token = eos_token
        self.bos_token = bos_token
        self.add_generation_prompt = add_generation_prompt
        # A set gives O(1) membership checks in the per-token stopping criterion.
        self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None

        self._environment = ImmutableSandboxedEnvironment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(self.template)

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """Render the template for ``messages``.

        The template receives the messages, the bos/eos tokens, the
        add_generation_prompt flag, the function/tool arguments and a
        ``raise_exception`` helper.

        Returns:
            A response with ``stop=[eos_token]``, plus a stopping criterion when
            ``stop_token_ids`` was configured.

        Raises:
            ValueError: if the template calls ``raise_exception(message)``.
        """

        def raise_exception(message: str):
            raise ValueError(message)

        prompt = self._environment.render(
            messages=messages,
            eos_token=self.eos_token,
            bos_token=self.bos_token,
            raise_exception=raise_exception,
            add_generation_prompt=self.add_generation_prompt,
            functions=functions,
            function_call=function_call,
            tools=tools,
            tool_choice=tool_choice,
        )

        stopping_criteria = None
        if self.stop_token_ids is not None:

            def stop_on_last_token(
                tokens: npt.NDArray[np.intc],
                logits: npt.NDArray[np.single],
            ) -> bool:
                # Only the most recently sampled token is inspected.
                return tokens[-1] in self.stop_token_ids

            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

        return ChatFormatterResponse(
            prompt=prompt,
            stop=[self.eos_token],
            stopping_criteria=stopping_criteria,
        )

    def to_chat_handler(self) -> LlamaChatCompletionHandler:
        """Wrap this formatter into a full chat completion handler."""
        return chat_formatter_to_chat_completion_handler(self)
2024-01-05 00:12:02 +01:00
2023-09-30 21:01:34 -04:00
2023-11-03 02:12:14 -04:00
def _convert_text_completion_to_chat (
completion : llama_types . Completion ,
) - > llama_types . ChatCompletion :
2024-01-18 21:21:37 -05:00
assert " usage " in completion
2023-11-03 02:12:14 -04:00
return {
" id " : " chat " + completion [ " id " ] ,
" object " : " chat.completion " ,
" created " : completion [ " created " ] ,
" model " : completion [ " model " ] ,
" choices " : [
{
" index " : 0 ,
" message " : {
" role " : " assistant " ,
" content " : completion [ " choices " ] [ 0 ] [ " text " ] ,
} ,
2024-04-01 02:30:13 +09:00
" logprobs " : completion [ " choices " ] [ 0 ] [ " logprobs " ] ,
2023-11-03 02:12:14 -04:00
" finish_reason " : completion [ " choices " ] [ 0 ] [ " finish_reason " ] ,
}
] ,
" usage " : completion [ " usage " ] ,
}
def _convert_text_completion_chunks_to_chat (
2023-11-08 04:48:51 +01:00
chunks : Iterator [ llama_types . CreateCompletionStreamResponse ] ,
2023-11-03 02:12:14 -04:00
) - > Iterator [ llama_types . ChatCompletionChunk ] :
for i , chunk in enumerate ( chunks ) :
if i == 0 :
yield {
" id " : " chat " + chunk [ " id " ] ,
" model " : chunk [ " model " ] ,
" created " : chunk [ " created " ] ,
" object " : " chat.completion.chunk " ,
" choices " : [
{
" index " : 0 ,
" delta " : {
" role " : " assistant " ,
} ,
2024-04-01 02:30:13 +09:00
" logprobs " : None ,
2023-11-03 02:12:14 -04:00
" finish_reason " : None ,
}
] ,
}
yield {
" id " : " chat " + chunk [ " id " ] ,
" model " : chunk [ " model " ] ,
" created " : chunk [ " created " ] ,
" object " : " chat.completion.chunk " ,
" choices " : [
{
" index " : 0 ,
2024-02-12 15:56:07 -05:00
" delta " : (
{
" content " : chunk [ " choices " ] [ 0 ] [ " text " ] ,
}
if chunk [ " choices " ] [ 0 ] [ " finish_reason " ] is None
else { }
) ,
2024-04-01 02:30:13 +09:00
" logprobs " : chunk [ " choices " ] [ 0 ] [ " logprobs " ] ,
2023-11-03 02:12:14 -04:00
" finish_reason " : chunk [ " choices " ] [ 0 ] [ " finish_reason " ] ,
}
] ,
}
def _convert_completion_to_chat(
    completion_or_chunks: Union[
        llama_types.CreateCompletionResponse,
        Iterator[llama_types.CreateCompletionStreamResponse],
    ],
    stream: bool = False,
) -> Union[
    llama_types.CreateChatCompletionResponse, Iterator[llama_types.ChatCompletionChunk]
]:
    """Convert a text completion (or its stream of chunks) into chat form.

    ``stream`` states which shape ``completion_or_chunks`` has: a single
    completion when falsy, an iterator of chunks when truthy.
    """
    if not stream:
        completion: llama_types.Completion = completion_or_chunks  # type: ignore
        return _convert_text_completion_to_chat(completion)
    chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore
    return _convert_text_completion_chunks_to_chat(chunks)
2024-03-19 04:55:57 -04:00
def _convert_completion_to_chat_function (
tool_name : str ,
completion_or_chunks : Union [
llama_types . CreateCompletionResponse ,
Iterator [ llama_types . CreateCompletionStreamResponse ] ,
] ,
stream : bool ,
) :
if not stream :
completion : llama_types . CreateCompletionResponse = completion_or_chunks # type: ignore
assert " usage " in completion
tool_id = " call_ " + " _0_ " + tool_name + " _ " + completion [ " id " ]
# TODO: Fix for legacy function calls
chat_completion : llama_types . CreateChatCompletionResponse = {
" id " : " chat " + completion [ " id " ] ,
" object " : " chat.completion " ,
" created " : completion [ " created " ] ,
" model " : completion [ " model " ] ,
" choices " : [
{
" index " : 0 ,
" message " : {
" role " : " assistant " ,
" content " : None ,
" function_call " : {
" name " : tool_name ,
" arguments " : completion [ " choices " ] [ 0 ] [ " text " ] ,
} ,
" tool_calls " : [
{
" id " : tool_id ,
" type " : " function " ,
" function " : {
" name " : tool_name ,
" arguments " : completion [ " choices " ] [ 0 ] [ " text " ] ,
} ,
}
] ,
} ,
2024-04-10 03:41:55 -04:00
" logprobs " : completion [ " choices " ] [ 0 ] [ " logprobs " ] ,
2024-03-19 04:55:57 -04:00
" finish_reason " : " tool_calls " ,
}
] ,
" usage " : completion [ " usage " ] ,
}
return chat_completion
else :
chunks : Iterator [ llama_types . CreateCompletionStreamResponse ] = completion_or_chunks # type: ignore
def _stream_response_to_function_stream (
chunks : Iterator [ llama_types . CreateCompletionStreamResponse ] ,
) - > Iterator [ llama_types . CreateChatCompletionStreamResponse ] :
# blank first message
first = True
id_ = None
created = None
model = None
tool_id = None
for chunk in chunks :
if first :
id_ = " chat " + chunk [ " id " ]
created = chunk [ " created " ]
model = chunk [ " model " ]
tool_id = " call_ " + " _0_ " + tool_name + " _ " + chunk [ " id " ]
yield {
" id " : id_ ,
" object " : " chat.completion.chunk " ,
" created " : created ,
" model " : model ,
" choices " : [
{
" index " : 0 ,
" finish_reason " : None ,
" logprobs " : None ,
" delta " : {
" role " : " assistant " ,
" content " : None ,
" function_call " : None ,
" tool_calls " : None ,
} ,
}
] ,
}
yield {
" id " : " chat " + chunk [ " id " ] ,
" object " : " chat.completion.chunk " ,
" created " : chunk [ " created " ] ,
" model " : chunk [ " model " ] ,
" choices " : [
{
" index " : 0 ,
" finish_reason " : None ,
2024-04-10 03:41:55 -04:00
" logprobs " : chunk [ " choices " ] [ 0 ] [ " logprobs " ] ,
2024-03-19 04:55:57 -04:00
" delta " : {
" role " : None ,
" content " : None ,
" function_call " : {
" name " : tool_name ,
" arguments " : chunk [ " choices " ] [ 0 ] [ " text " ] ,
} ,
" tool_calls " : [
{
" index " : 0 ,
" id " : tool_id ,
" type " : " function " ,
" function " : {
" name " : tool_name ,
2024-03-22 23:44:04 -04:00
" arguments " : chunk [ " choices " ] [ 0 ] [ " text " ] ,
2024-03-19 04:55:57 -04:00
} ,
}
] ,
} ,
}
] ,
}
first = False
continue
assert tool_id is not None
yield {
" id " : " chat " + chunk [ " id " ] ,
" object " : " chat.completion.chunk " ,
" created " : chunk [ " created " ] ,
" model " : chunk [ " model " ] ,
" choices " : [
{
" index " : 0 ,
" finish_reason " : None ,
2024-04-10 03:41:55 -04:00
" logprobs " : chunk [ " choices " ] [ 0 ] [ " logprobs " ] ,
2024-03-19 04:55:57 -04:00
" delta " : {
" role " : None ,
" content " : None ,
" function_call " : {
" name " : tool_name ,
" arguments " : chunk [ " choices " ] [ 0 ] [ " text " ] ,
} ,
" tool_calls " : [
{
" index " : 0 ,
" id " : tool_id ,
" type " : " function " ,
" function " : {
" name " : tool_name ,
" arguments " : chunk [ " choices " ] [ 0 ] [
" text "
] ,
} ,
}
] ,
} ,
}
] ,
}
if id_ is not None and created is not None and model is not None :
yield {
" id " : id_ ,
" object " : " chat.completion.chunk " ,
" created " : created ,
" model " : model ,
" choices " : [
{
" index " : 0 ,
" finish_reason " : " tool_calls " ,
" logprobs " : None ,
" delta " : {
" role " : None ,
" content " : None ,
" function_call " : None ,
" tool_calls " : None ,
} ,
}
] ,
}
return _stream_response_to_function_stream ( chunks )
2024-01-18 21:21:37 -05:00
def chat_formatter_to_chat_completion_handler(
    chat_formatter: ChatFormatter,
) -> LlamaChatCompletionHandler:
    """Adapt a ``ChatFormatter`` into a ``LlamaChatCompletionHandler``.

    The returned handler performs four steps:

    1. Formats ``messages`` (plus function/tool metadata) into a prompt with
       ``chat_formatter``.
    2. Appends the formatter's stop strings to the request's ``stop`` and
       forwards the formatter's ``stopping_criteria``.
    3. Chooses a grammar. A JSON grammar is used for a ``response_format`` of
       type ``"json_object"``. It is overridden by the forced tool's parameter
       schema when ``tool_choice`` (or legacy ``function_call``) names a tool.
    4. Runs ``llama.create_completion`` and converts the result to
       chat-completion form. The result is a tool call when a tool was forced.

    Raises:
        ValueError: (from the handler) if ``tool_choice`` names a tool that is
            not in ``tools``.
    """

    def chat_completion_handler(
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        result = chat_formatter(
            messages=messages,
            functions=functions,
            function_call=function_call,
            tools=tools,
            tool_choice=tool_choice,
        )
        prompt = result.prompt

        if result.stop is not None:
            # Normalise the request's stop value to a list, then append the
            # formatter's stop strings. Concatenation builds a new list, so the
            # shared `[]` default is never mutated.
            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
            stop = stop + rstop

        stopping_criteria = result.stopping_criteria

        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)

        # Convert legacy functions to tools
        if functions is not None:
            tools = [{"type": "function", "function": function} for function in functions]

        # Convert legacy function_call to tool_choice
        if function_call is not None:
            if isinstance(function_call, str) and function_call in ("none", "auto"):
                tool_choice = function_call
            if isinstance(function_call, dict) and "name" in function_call:
                tool_choice = {
                    "type": "function",
                    "function": {"name": function_call["name"]},
                }

        tool = None
        if isinstance(tool_choice, dict) and tools is not None:
            name = tool_choice["function"]["name"]
            tool = next((t for t in tools if t["function"]["name"] == name), None)
            if tool is None:
                raise ValueError(f"Tool choice '{name}' not found in tools.")
            schema = tool["function"]["parameters"]
            # Constrain generation to the tool's JSON schema. The shared helper
            # falls back to the generic JSON grammar when the schema cannot be
            # converted (same policy as response_format schemas).
            grammar = _grammar_for_json_schema(json.dumps(schema), verbose=llama.verbose)

        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            logprobs=top_logprobs if logprobs else None,
            stream=stream,
            stop=stop,
            seed=seed,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            grammar=grammar,
            logit_bias=logit_bias,
        )
        if tool is not None:
            tool_name = tool["function"]["name"]
            return _convert_completion_to_chat_function(
                tool_name, completion_or_chunks, stream
            )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

    return chat_completion_handler
2023-09-29 19:52:04 -04:00
2023-11-08 04:48:51 +01:00
def hf_autotokenizer_to_chat_formatter(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> ChatFormatter:
    """Build a chat formatter backed by a Hugging Face ``AutoTokenizer``.

    ``transformers`` is imported lazily, so it is only required when this
    function is actually used.

    References:
        https://huggingface.co/docs/transformers/main/chat_templating
        https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
        https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
    """
    from transformers import AutoTokenizer  # type: ignore

    hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # type: ignore

    def format_autotokenizer(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # Presumably keeps transformers from injecting its default system prompt.
        hf_tokenizer.use_default_system_prompt = False  # type: ignore
        # NOTE(review): apply_chat_template is called without
        # add_generation_prompt, so templates that only open the assistant
        # turn when that flag is set will not do so here -- confirm intended.
        rendered: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)  # type: ignore
        assert isinstance(rendered, str)
        # The tokenizer's EOS token doubles as the stop string.
        return ChatFormatterResponse(prompt=rendered, stop=hf_tokenizer.eos_token)

    return format_autotokenizer
2024-01-18 21:21:37 -05:00
def hf_autotokenizer_to_chat_completion_handler(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> LlamaChatCompletionHandler:
    """Chat completion handler whose prompts come from a Hugging Face AutoTokenizer."""
    return chat_formatter_to_chat_completion_handler(
        hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
    )
2024-01-19 15:04:42 -05:00
def hf_tokenizer_config_to_chat_formatter(
    tokenizer_config: Dict[str, Any],
    add_generation_prompt: bool = True,
) -> ChatFormatter:
    """Build a chat formatter from a Hugging Face ``tokenizer_config.json`` mapping.

    Args:
        tokenizer_config: Parsed tokenizer config. It must provide string
            ``chat_template``, ``bos_token`` and ``eos_token`` entries.
        add_generation_prompt: Passed to the template as
            ``add_generation_prompt``, so templates that support it open an
            assistant turn at the end of the prompt.

    Returns:
        A formatter whose responses use ``[eos_token, bos_token]`` as stop strings.

    Raises:
        AssertionError: if a required config entry is missing or not a string.
        ValueError: (from the formatter) if the template calls
            ``raise_exception(message)``.
    """
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    assert isinstance(tokenizer_config, dict)

    assert "chat_template" in tokenizer_config
    assert isinstance(tokenizer_config["chat_template"], str)
    chat_template = tokenizer_config["chat_template"]

    assert "bos_token" in tokenizer_config
    assert isinstance(tokenizer_config["bos_token"], str)
    bos_token = tokenizer_config["bos_token"]

    assert "eos_token" in tokenizer_config
    assert isinstance(tokenizer_config["eos_token"], str)
    eos_token = tokenizer_config["eos_token"]

    # The template comes from an external config file. Render it in jinja2's
    # immutable sandbox so it cannot reach Python internals through attribute
    # access (server-side template injection).
    env = ImmutableSandboxedEnvironment(
        loader=jinja2.BaseLoader(),
        trim_blocks=True,
        lstrip_blocks=True,
    ).from_string(chat_template)

    def raise_exception(message: str):
        # Chat templates (e.g. the Mistral one) call raise_exception(...) to
        # reject malformed conversations; surface that as a ValueError.
        raise ValueError(message)

    def format_tokenizer_config(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # Let the template open the assistant turn via `add_generation_prompt`
        # rather than appending an empty assistant message. Templates render an
        # empty assistant message as a *closed* turn: ChatML emits
        # "<|im_start|>assistant\n<|im_end|>\n" and Mistral emits "</s> ".
        # That ends the reply before the model generates anything.
        prompt = env.render(
            messages=messages,
            bos_token=bos_token,
            eos_token=eos_token,
            add_generation_prompt=add_generation_prompt,
            raise_exception=raise_exception,
        )
        return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])

    return format_tokenizer_config
2024-01-18 21:21:37 -05:00
def hf_tokenizer_config_to_chat_completion_handler(
    tokenizer_config: Dict[str, Any],
    add_generation_prompt: bool = True,
) -> LlamaChatCompletionHandler:
    """Chat completion handler driven by a ``tokenizer_config.json`` chat template.

    ``add_generation_prompt`` is forwarded unchanged to
    ``hf_tokenizer_config_to_chat_formatter``.
    """
    return chat_formatter_to_chat_completion_handler(
        hf_tokenizer_config_to_chat_formatter(
            tokenizer_config, add_generation_prompt=add_generation_prompt
        )
    )
2024-01-29 14:22:23 -05:00
def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
    """Guess a chat format name from GGUF metadata.

    ``metadata["tokenizer.chat_template"]`` is compared for exact equality
    against the bundled templates.

    Returns:
        ``"chatml"``, ``"mistral-instruct"`` (for the Mistral or Mixtral
        template), or ``"llama-3"``. Returns None when the key is absent or
        no template matches.
    """
    if "tokenizer.chat_template" not in metadata:
        return None
    template = metadata["tokenizer.chat_template"]
    known_templates = (
        (CHATML_CHAT_TEMPLATE, "chatml"),
        (MISTRAL_INSTRUCT_CHAT_TEMPLATE, "mistral-instruct"),
        (MIXTRAL_INSTRUCT_CHAT_TEMPLATE, "mistral-instruct"),
        (LLAMA3_INSTRUCT_CHAT_TEMPLATE, "llama-3"),
    )
    for candidate, format_name in known_templates:
        if template == candidate:
            return format_name
    return None
2024-02-12 15:56:07 -05:00
2024-01-19 15:04:42 -05:00
### Utility functions for formatting chat prompts ###
2024-01-29 14:22:23 -05:00
# TODO: Replace these with jinja2 templates
2024-01-19 15:04:42 -05:00
def _get_system_message (
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
) - > str :
""" Get the first system message. """
for message in messages :
if message [ " role " ] == " system " :
return message [ " content " ] or " "
return " "
def _map_roles (
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
role_map : Dict [ str , str ] ,
) - > List [ Tuple [ str , Optional [ str ] ] ] :
""" Map the message roles. """
output : List [ Tuple [ str , Optional [ str ] ] ] = [ ]
for message in messages :
role = message [ " role " ]
if role in role_map :
content : str | None = (
message [ " content " ] if isinstance ( message [ " content " ] , str ) else None
)
output . append ( ( role_map [ role ] , content ) )
return output
def _format_llama2 (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str , sep2 : str
) - > str :
""" Format the prompt with the llama2 style. """
seps = [ sep , sep2 ]
ret = system_message + sep
for i , ( role , message ) in enumerate ( messages ) :
if system_message and i == 0 :
m = message or " "
ret + = m + seps [ i % 2 ]
elif message :
ret + = role + message + " " + seps [ i % 2 ]
else :
ret + = role + " "
return ret
def _format_add_colon_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the add-colon-single style. """
ret = system_message + sep
for role , message in messages :
if message :
ret + = role + " : " + message + sep
else :
ret + = role + " : "
return ret
def _format_add_colon_two (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str , sep2 : str
) - > str :
""" Format the prompt with the add-colon-two style. """
seps = [ sep , sep2 ]
ret = system_message + seps [ 0 ]
for i , ( role , message ) in enumerate ( messages ) :
if message :
ret + = role + " : " + message + seps [ i % 2 ]
else :
ret + = role + " : "
return ret
def _format_no_colon_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the no-colon-single style. """
ret = system_message
for role , message in messages :
if message :
ret + = role + message + sep
else :
ret + = role
return ret
def _format_add_colon_space_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the add-colon-space-single style. """
ret = system_message + sep
for role , message in messages :
if message :
ret + = role + " : " + message + sep
else :
ret + = role + " : " # must be end with a space
return ret
def _format_chatml (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the chatml style. """
ret = " " if system_message == " " else system_message + sep + " \n "
for role , message in messages :
if message :
ret + = role + " \n " + message + sep + " \n "
else :
ret + = role + " \n "
return ret
def _format_chatglm3 (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the chatglm3 style. """
ret = " "
if system_message :
ret + = system_message
for role , message in messages :
if message :
ret + = role + " \n " + " " + message
else :
ret + = role
return ret
2024-03-15 12:58:34 -04:00
def _grammar_for_json(verbose: bool = False):
    """Return the generic JSON grammar built from ``llama_grammar.JSON_GBNF``."""
    json_gbnf = llama_grammar.JSON_GBNF
    return llama_grammar.LlamaGrammar.from_string(json_gbnf, verbose=verbose)
def _grammar_for_json_schema(
    schema: str,
    verbose: bool = False,
    fallback_to_json: bool = True
):
    """Compile a JSON-schema string into a grammar, optionally falling back.

    Args:
        schema: JSON schema serialized as a string.
        verbose: forwarded to the ``llama_grammar`` constructors.
        fallback_to_json: when the schema cannot be compiled, return the
            generic JSON grammar from ``_grammar_for_json`` instead of raising.

    Returns:
        A ``llama_grammar.LlamaGrammar``.

    Raises:
        Exception: whatever ``LlamaGrammar.from_json_schema`` raised, re-raised
            unchanged when ``fallback_to_json`` is false.
    """
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
    except Exception as e:
        if not fallback_to_json:
            # Bare ``raise`` re-raises the original exception and traceback.
            raise
        # The fallback loosens the constraint from the requested schema to
        # "any JSON", so report it instead of swallowing the error silently.
        logger.warning(
            "Could not build a grammar from the JSON schema (%s); "
            "falling back to the generic JSON grammar.",
            e,
        )
        return _grammar_for_json(verbose=verbose)
def _grammar_for_response_format(
    response_format: llama_types.ChatCompletionRequestResponseFormat,
    verbose: bool = False
):
    """Translate an OpenAI-style ``response_format`` into a grammar.

    Returns ``None`` unless ``response_format["type"]`` is ``"json_object"``.
    With a ``"schema"`` key, the schema is serialized with ``json.dumps`` and
    compiled by ``_grammar_for_json_schema``. Without one, the generic JSON
    grammar from ``_grammar_for_json`` is returned.
    """
    if response_format["type"] != "json_object":
        return None
    if "schema" not in response_format:
        return _grammar_for_json(verbose=verbose)
    schema_json = json.dumps(response_format["schema"])
    return _grammar_for_json_schema(schema_json, verbose=verbose)
2024-01-19 15:04:42 -05:00
### Chat Formats ###
def register_chat_format(name: str):
    """Return a decorator that registers a chat formatter under ``name``.

    The decorated formatter is converted with
    ``chat_formatter_to_chat_completion_handler``. The resulting handler is
    registered on a ``LlamaChatCompletionHandlerRegistry()`` instance, which is
    presumably a shared singleton (confirm). The formatter itself is returned
    unchanged, so it stays directly callable.
    """
    def decorator(f: ChatFormatter):
        handler = chat_formatter_to_chat_completion_handler(f)
        registry = LlamaChatCompletionHandlerRegistry()
        registry.register_chat_completion_handler(name, handler)
        return f

    return decorator
2023-11-05 17:00:13 -05:00
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
2023-09-29 19:52:04 -04:00
@register_chat_format("llama-2")
def format_llama2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Llama-2 chat prompt.

    User turns are tagged ``<s>[INST]`` and assistant turns ``[/INST]``. A
    truthy system message is wrapped in ``<<SYS>>`` markers and passed to
    ``_format_llama2`` as the prompt prefix, with ``" "``/``"</s>"`` as the
    turn separators. ``[/INST]`` is appended so generation continues as the
    assistant.
    """
    sys_wrapper = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
    role_tags = dict(user="<s>[INST]", assistant="[/INST]")
    turns = _map_roles(messages, role_tags)
    system_text = _get_system_message(messages)
    if system_text:
        system_text = sys_wrapper.format(system_message=system_text)
    prompt = _format_llama2(system_text, turns, " ", "</s>")
    return ChatFormatterResponse(prompt=prompt + "[/INST]")
2024-04-23 06:33:29 +00:00
# Chat format for Llama-3 models, see more details at:
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
@register_chat_format("llama-3")
def format_llama3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Llama-3 instruct prompt.

    Each system, user or assistant turn is introduced by a
    ``<|start_header_id|>ROLE<|end_header_id|>\\n\\n`` header and terminated by
    ``<|eot_id|>``. The prompt starts with ``<|begin_of_text|>`` and ends with
    an open assistant header. ``<|eot_id|>`` is returned as the stop string.
    """
    eot = "<|eot_id|>"
    headers = {
        role: "<|start_header_id|>" + role + "<|end_header_id|>\n\n"
        for role in ("system", "user", "assistant")
    }
    turns = _map_roles(messages, headers)
    turns.append((headers["assistant"], None))
    prompt = _format_no_colon_single("<|begin_of_text|>", turns, eot)
    return ChatFormatterResponse(prompt=prompt, stop=eot)
2023-09-29 19:52:04 -04:00
@register_chat_format("alpaca")
def format_alpaca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an Alpaca-style prompt.

    User turns use the ``### Instruction`` header and assistant turns
    ``### Response``. ``_format_add_colon_two`` renders them as
    ``"header: text"`` entries after the system message, alternating the
    ``\\n\\n`` and ``</s>`` terminators. An empty assistant turn is appended so
    the prompt ends with ``### Response:``, which cues the model to answer the
    last instruction. This matches the other formatters in this module.
    """
    _roles = dict(user="### Instruction", assistant="### Response")
    _sep = "\n\n"
    _sep2 = "</s>"
    system_message = _get_system_message(messages)
    _messages = _map_roles(messages, _roles)
    # Open the assistant turn; without it the prompt stopped right after the
    # last instruction and never emitted the "### Response:" header.
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
    return ChatFormatterResponse(prompt=_prompt)
2024-01-18 21:21:37 -05:00
2023-12-14 10:43:43 +08:00
@register_chat_format("qwen")
def format_qwen(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Qwen (ChatML-style) prompt.

    The request's system message is used when present. Otherwise Qwen's stock
    ``"You are a helpful assistant."`` is used. ``_format_chatml`` renders the
    turns with ``<|im_start|>ROLE`` headers and ``<|im_end|>`` terminators,
    ending with an open assistant turn. ``<|endoftext|>`` is returned as the
    stop string.
    """
    _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
    # Honour a caller-supplied system message and fall back to Qwen's default
    # only when none was given. The default used to be applied unconditionally.
    system_message = _get_system_message(messages) or "You are a helpful assistant."
    system_template = "<|im_start|>system\n{system_message}"
    system_message = system_template.format(system_message=system_message)
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _sep = "<|im_end|>"
    _prompt = _format_chatml(system_message, _messages, _sep)
    _sep2 = "<|endoftext|>"
    return ChatFormatterResponse(prompt=_prompt, stop=_sep2)
2023-09-29 19:52:04 -04:00
@register_chat_format("vicuna")
def format(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Vicuna v1.1-style prompt.

    Uses a fixed system prompt; any system message in ``messages`` is not
    used. Turns are ``USER``/``ASSISTANT`` entries rendered by
    ``_format_add_colon_two`` with ``" "`` and ``"</s>"`` as alternating
    terminators, ending with an open ``ASSISTANT:`` turn.

    Note: this name shadows the built-in ``format`` at module level. It is
    kept for backward compatibility.
    """
    vicuna_system = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    turns = _map_roles(messages, {"user": "USER", "assistant": "ASSISTANT"})
    turns.append(("ASSISTANT", None))
    prompt = _format_add_colon_two(vicuna_system, turns, " ", "</s>")
    return ChatFormatterResponse(prompt=prompt)
@register_chat_format("oasst_llama")
def format_oasst_llama(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenAssistant-on-Llama prompt.

    The system message is always wrapped in ``[INST] <<SYS>> ... <</SYS>>``,
    even when it is empty. Turns are tagged ``<|prompter|>``/``<|assistant|>``
    and terminated by ``</s>`` via ``_format_no_colon_single``. The prompt
    ends with an open ``<|assistant|>`` tag.
    """
    sys_text = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, {"user": "<|prompter|>", "assistant": "<|assistant|>"})
    turns.append(("<|assistant|>", None))
    prompt = _format_no_colon_single(sys_text, turns, "</s>")
    return ChatFormatterResponse(prompt=prompt)
2023-11-22 19:08:06 +08:00
@register_chat_format("baichuan-2")
def format_baichuan2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Baichuan-2 prompt.

    The system message is emitted as-is, followed by turns tagged
    ``<reserved_106>`` (user) and ``<reserved_107>`` (assistant). The turns
    are concatenated with no separator and end with an open
    ``<reserved_107>`` tag.
    """
    user_tag, bot_tag = "<reserved_106>", "<reserved_107>"
    # f-string without a spec is equivalent to "{system_message}".format(...).
    system_text = f"{_get_system_message(messages)}"
    turns = _map_roles(messages, {"user": user_tag, "assistant": bot_tag})
    turns.append((bot_tag, None))
    return ChatFormatterResponse(prompt=_format_no_colon_single(system_text, turns, ""))
2023-11-23 14:19:50 +08:00
@register_chat_format("baichuan")
def format_baichuan(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Baichuan (v1) prompt.

    The system message is emitted as-is, followed by turns tagged
    ``<reserved_102>`` (user) and ``<reserved_103>`` (assistant). The turns
    are concatenated with no separator and end with an open
    ``<reserved_103>`` tag.
    """
    user_tag, bot_tag = "<reserved_102>", "<reserved_103>"
    # f-string without a spec is equivalent to "{system_message}".format(...).
    system_text = f"{_get_system_message(messages)}"
    turns = _map_roles(messages, {"user": user_tag, "assistant": bot_tag})
    turns.append((bot_tag, None))
    return ChatFormatterResponse(prompt=_format_no_colon_single(system_text, turns, ""))
2023-09-29 19:52:04 -04:00
@register_chat_format("openbuddy")
def format_openbuddy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenBuddy prompt.

    Uses OpenBuddy's fixed system prompt; any system message in ``messages``
    is not used. Turns are ``User``/``Assistant`` entries rendered by
    ``_format_add_colon_single`` with ``\\n`` as separator, ending with an
    open ``Assistant`` turn.
    """
    buddy_prompt = """You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.
Always answer as helpfully and logically as possible, while being safe. Your answers should not include any harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
You can speak fluently in many languages, for example: English, Chinese.
You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.
You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.

"""
    turns = _map_roles(messages, {"user": "User", "assistant": "Assistant"})
    turns.append(("Assistant", None))
    prompt = _format_add_colon_single(buddy_prompt, turns, "\n")
    return ChatFormatterResponse(prompt=prompt)
@register_chat_format("redpajama-incite")
def format_redpajama_incite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a RedPajama-INCITE chat prompt.

    The request's system message is the prefix. Turns are tagged
    ``<human>``/``<bot>`` and rendered by ``_format_add_colon_single`` with
    ``\\n`` as separator, ending with an open ``<bot>`` turn. ``<human>`` is
    returned as the stop string.
    """
    human, bot = "<human>", "<bot>"
    system_text = _get_system_message(messages)
    turns = _map_roles(messages, {"user": human, "assistant": bot})
    turns.append((bot, None))
    prompt = _format_add_colon_single(system_text, turns, "\n")
    return ChatFormatterResponse(prompt=prompt, stop=human)
@register_chat_format("snoozy")
def format_snoozy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a GPT4All-Snoozy prompt.

    The system message is rendered under a ``### Instruction:`` header. When
    the request has no system message, a default instruction is used instead.
    User turns use ``### Prompt`` and assistant turns ``### Response``.
    ``_format_add_colon_single`` renders them with ``\\n`` as separator,
    ending with an open ``### Response`` turn. ``###`` is returned as the
    stop string.
    """
    system_template = "### Instruction:\n{system_message}"
    default_system_message = "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response."
    _system_message = _get_system_message(messages)
    _system_message = (
        _system_message if _system_message != "" else default_system_message
    )
    # Pass the templated text, including its "### Instruction:" header, to the
    # formatter. The raw system text used to overwrite it here, which made the
    # template and the header dead code.
    system_message = system_template.format(system_message=_system_message)
    _roles = dict(user="### Prompt", assistant="### Response")
    _sep = "\n"
    _stop = "###"
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_single(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt, stop=_stop)
@register_chat_format("phind")
def format_phind(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Phind-CodeLlama prompt.

    Uses a fixed ``### System Prompt`` block; any system message in
    ``messages`` is not used. Turns use the ``### User Message`` and
    ``### Assistant`` headers and are rendered by ``_format_add_colon_single``
    with blank-line (``\\n\\n``) separators, ending with an open
    ``### Assistant`` turn.
    """
    assistant_header = "### Assistant"
    turns = _map_roles(
        messages, {"user": "### User Message", "assistant": assistant_header}
    )
    turns.append((assistant_header, None))
    prompt = _format_add_colon_single(
        "### System Prompt\nYou are an intelligent programming assistant.",
        turns,
        "\n\n",
    )
    return ChatFormatterResponse(prompt=prompt)
2023-11-21 04:02:20 -05:00
2023-11-20 21:19:25 -08:00
@register_chat_format("intel")
def format_intel(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an Intel neural-chat prompt.

    Produces one header line per section followed by its content::

        ### System:
        {system}
        ### User:
        {user}
        ### Assistant:
        {assistant}
        ...
        ### Assistant:

    The prompt ends with an open ``### Assistant:`` header for the model to
    complete.
    """
    _roles = dict(user="### User:\n", assistant="### Assistant:\n")
    _sep = "\n"
    _system_template = "### System:\n{system_message}\n"
    # Substitute the request's system message. The template used to be passed
    # through unformatted, leaking a literal "{system_message}" into the prompt.
    system_message = _system_template.format(
        system_message=_get_system_message(messages)
    )
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    # The role tags already carry their ":" and newline, so use the no-colon
    # formatter. The add-colon formatters append ": " after the role, which
    # produced "### User:: ..." headers.
    _prompt = _format_no_colon_single(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt)
2023-09-29 19:52:04 -04:00
@register_chat_format("open-orca")
def format_open_orca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenOrca prompt.

    Uses the fixed system prompt below; any system message in ``messages`` is
    not used. Request roles ``user``/``assistant`` are rendered as
    ``User``/``Assistant`` by ``_format_add_colon_space_single``, with
    ``<|end_of_turn|>\\n`` after the system prompt and after each turn. The
    prompt ends with an open ``"Assistant: "`` turn. ``"User"`` is returned as
    the stop string.
    """
    system_template = "{system_message}"
    system_message = (
        "You are a helpful assistant. Please answer truthfully and write out your "
        "thinking step by step to be sure you get the right answer. If you make a mistake or encounter "
        "an error in your thinking, say so out loud and attempt to correct it. If you don't know or "
        "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, "
        "and physicist. You will also act as the most appropriate type of expert to answer any particular "
        "question or solve the relevant problem; state which expert type your are, if so. Also think of "
        "any particular named expert that would be ideal to answer the relevant question or solve the "
        "relevant problem; name and act as them, if appropriate."
    )
    roles = ("User", "Assistant")
    sep = "<|end_of_turn|>\n"
    # stop_token_ids=[32000, 32001], # "<|end_of_turn|>"
    stop_str = "User"
    system_message = system_template.format(system_message=system_message)
    # The role map must be keyed on request roles ("user"/"assistant"), as in
    # every other formatter here. The previous dict(zip(roles, roles)) was
    # keyed on "User"/"Assistant", which never match, so no turn was mapped.
    _messages = _map_roles(messages, dict(user=roles[0], assistant=roles[1]))
    _messages.append((roles[1], None))
    _prompt = _format_add_colon_space_single(system_message, _messages, sep)
    return ChatFormatterResponse(prompt=_prompt, stop=stop_str)
2023-09-30 21:01:34 -04:00
2023-11-20 21:19:25 -08:00
@register_chat_format("mistrallite")
def format_mistrallite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a MistralLite prompt.

    The system message is wrapped as ``<|system|>...</s>``. User turns are
    tagged ``<|prompter|>`` and assistant turns ``</s>\\n<|assistant|>``, and
    ``_format_no_colon_single`` joins them with ``" "``. The prompt ends with
    an open assistant tag.
    """
    sys_text = "<|system|>{system_message}</s>".format(
        system_message=_get_system_message(messages)
    )
    role_tags = {"user": "<|prompter|>", "assistant": "</s>\n<|assistant|>"}
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_no_colon_single(sys_text, turns, " "))
2024-01-18 21:21:37 -05:00
2023-11-22 22:20:08 -08:00
@register_chat_format("zephyr")
def format_zephyr(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Zephyr prompt.

    ``_format_chatml`` produces
    ``<|system|>\\n{system}</s>\\n<|user|>\\n{user}</s>\\n ... <|assistant|>\\n``,
    ending with an open assistant turn. ``</s>`` is returned as the stop
    string.
    """
    system_template = "<|system|>\n{system_message}"
    system_message = _get_system_message(messages)
    system_message = system_template.format(system_message=system_message)
    # _format_chatml already emits "\n" after each role tag, so the tags must
    # not carry their own newline. They did, which produced "<|user|>\n\n...".
    _roles = dict(user="<|user|>", assistant="<|assistant|>")
    _sep = "</s>"
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_chatml(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
2023-11-21 04:02:20 -05:00
2023-12-12 09:44:04 +08:00
@register_chat_format("pygmalion")
def format_pygmalion(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Pygmalion prompt.

    The system message is prefixed with ``<|system|>``. Turns are tagged
    ``<|user|>``/``<|model|>`` and rendered by ``_format_chatml`` with ``\\n``
    as the turn terminator, ending with an open ``<|model|>`` turn. ``\\n`` is
    also returned as the stop string.
    """
    newline = "\n"
    sys_text = "<|system|>" + "{system_message}".format(
        system_message=_get_system_message(messages)
    )
    role_tags = {"user": "<|user|>", "assistant": "<|model|>"}
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    prompt = _format_chatml(sys_text, turns, newline)
    return ChatFormatterResponse(prompt=prompt, stop=newline)
2023-09-30 21:01:34 -04:00
@register_chat_format("chatml")
def format_chatml(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a ChatML prompt.

    The system message becomes ``<|im_start|>system\\n{system}``. Turns use
    ``<|im_start|>user``/``<|im_start|>assistant`` headers and ``<|im_end|>``
    terminators via ``_format_chatml``, ending with an open assistant turn.
    ``<|im_end|>`` is returned as the stop string.
    """
    im_end = "<|im_end|>"
    sys_text = "<|im_start|>system\n{system_message}".format(
        system_message=_get_system_message(messages)
    )
    role_tags = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    return ChatFormatterResponse(prompt=_format_chatml(sys_text, turns, im_end), stop=im_end)
2023-11-03 02:12:14 -04:00
2024-01-18 21:21:37 -05:00
2024-01-29 02:34:42 -03:00
@register_chat_format("mistral-instruct")
def format_mistral_instruct(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Mistral-Instruct prompt.

    The prompt starts with ``<s>``. Each user message with string content
    adds ``"[INST] " + content``. Each assistant message with non-``None``
    content adds ``" [/INST]" + content + "</s>"``. Other messages, including
    ``system`` ones and user messages with non-string content, are skipped.
    The prompt is closed with ``" [/INST]"``, and ``</s>`` is returned as the
    stop string.
    """
    eos = "</s>"
    pieces = ["<s>"]
    for message in messages:
        role = message["role"]
        if role not in ("user", "assistant"):
            continue
        content = message["content"]
        if role == "user":
            if isinstance(content, str):
                pieces.append("[INST] " + content)
        elif content is not None:
            pieces.append(" [/INST]" + content + eos)
    pieces.append(" [/INST]")
    return ChatFormatterResponse(prompt="".join(pieces), stop=eos)
2024-01-29 02:34:42 -03:00
2024-01-05 00:12:02 +01:00
@register_chat_format("chatglm3")
def format_chatglm3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a ChatGLM3 prompt.

    The system message is rendered as ``<|system|>\\n{system}``. Turns use
    ``<|user|>``/``<|assistant|>`` tags and are assembled by
    ``_format_chatglm3``, ending with an open assistant tag. ``</s>`` is
    returned as the stop string.
    """
    stop_token = "</s>"
    role_tags = {"user": "<|user|>", "assistant": "<|assistant|>"}
    system_text = "<|system|>\n{system_message}".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    return ChatFormatterResponse(
        prompt=_format_chatglm3(system_text, turns, stop_token), stop=stop_token
    )
2023-11-21 04:02:20 -05:00
2023-11-20 21:19:25 -08:00
@register_chat_format("openchat")
def format_openchat(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build an OpenChat prompt.

    The system message is followed by ``<|end_of_turn|>``. User turns are
    tagged ``GPT4 Correct User: `` and assistant turns
    ``<|end_of_turn|>GPT4 Correct Assistant: ``. They are rendered by
    ``_format_chatml`` with ``<|end_of_turn|>`` terminators, ending with an
    open assistant turn. ``<|end_of_turn|>`` is returned as the stop string.
    """
    eot = "<|end_of_turn|>"
    system_text = "{system_message}<|end_of_turn|>".format(
        system_message=_get_system_message(messages)
    )
    role_tags = dict(
        user="GPT4 Correct User: ",
        assistant="<|end_of_turn|>GPT4 Correct Assistant: ",
    )
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    prompt = _format_chatml(system_text, turns, eot)
    return ChatFormatterResponse(prompt=prompt, stop=eot)
2024-02-12 15:56:07 -05:00
2024-01-05 06:12:58 +07:00
# Chat format for Saiga models, see more details and available models:
# https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
@register_chat_format("saiga")
def format_saiga(
    messages: list[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Saiga prompt.

    Each turn with content is rendered as ``<s>{role}\\n{content}</s>`` using
    the Saiga role names ``system``, ``user`` and ``bot``. A turn without
    content becomes ``<s>{role}\\n``. ``<s>bot`` is appended to cue the
    model's reply, and the whole prompt is stripped of surrounding
    whitespace.
    """
    _message_template = "<s>{role}\n{content}</s>"
    # Saiga calls the assistant "bot", but requests carry role "assistant", so
    # map it explicitly. The old map only had a "bot" key, so assistant turns
    # never matched. "bot" is kept for backward compatibility.
    _roles = dict(user="user", assistant="bot", bot="bot", system="system")
    _messages = _map_roles(messages, _roles)
    _prompt = ""
    for role, content in _messages:
        if content:
            _prompt += _message_template.format(role=role, content=content)
        else:
            _prompt += f"<s>{role}\n"
    # Response template
    _prompt += "<s>bot"
    return ChatFormatterResponse(prompt=_prompt.strip())
2024-02-12 15:56:07 -05:00
2024-02-23 18:40:52 +09:00
# Chat format for Google's Gemma models, see more details and available models:
# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
@register_chat_format("gemma")
def format_gemma(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Build a Gemma prompt.

    Gemma has no system role. A non-empty system message is only reported via
    ``logger.debug`` and is not passed to the formatter. User turns are
    introduced by ``<start_of_turn>user\\n`` and assistant turns by
    ``<start_of_turn>model\\n``. Every turn ends with ``<end_of_turn>\\n``,
    which is also the stop string. The prompt ends with an open model turn.
    """
    end_of_turn = "<end_of_turn>\n"
    user_tag = "<start_of_turn>user\n"
    model_tag = "<start_of_turn>model\n"
    if _get_system_message(messages) != "":
        logger.debug(
            "`role='system'` messages are not allowed on Google's Gemma models."
        )
    turns = _map_roles(messages, {"user": user_tag, "assistant": model_tag})
    turns.append((model_tag, None))
    prompt = _format_no_colon_single(system_message="", messages=turns, sep=end_of_turn)
    return ChatFormatterResponse(prompt=prompt, stop=end_of_turn)
2024-01-29 14:22:23 -05:00
# Tricky chat formats that require custom chat handlers
2024-01-05 06:12:58 +07:00
2024-02-12 15:56:07 -05:00
2023-11-03 02:12:14 -04:00
@register_chat_completion_handler ( " functionary " )
def functionary_chat_handler (
llama : llama . Llama ,
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
functions : Optional [ List [ llama_types . ChatCompletionFunction ] ] = None ,
2023-11-08 04:48:51 +01:00
function_call : Optional [ llama_types . ChatCompletionRequestFunctionCall ] = None ,
2023-11-10 02:51:58 -05:00
tools : Optional [ List [ llama_types . ChatCompletionTool ] ] = None ,
tool_choice : Optional [ llama_types . ChatCompletionToolChoiceOption ] = None ,
2023-11-03 02:12:14 -04:00
temperature : float = 0.2 ,
top_p : float = 0.95 ,
top_k : int = 40 ,
2023-11-21 06:21:33 +02:00
min_p : float = 0.05 ,
typical_p : float = 1.0 ,
2023-11-03 02:12:14 -04:00
stream : bool = False ,
stop : Optional [ Union [ str , List [ str ] ] ] = [ ] ,
2023-11-09 00:55:23 -05:00
response_format : Optional [ llama_types . ChatCompletionRequestResponseFormat ] = None ,
2023-11-10 02:51:58 -05:00
max_tokens : Optional [ int ] = None ,
2023-11-03 02:12:14 -04:00
presence_penalty : float = 0.0 ,
frequency_penalty : float = 0.0 ,
repeat_penalty : float = 1.1 ,
tfs_z : float = 1.0 ,
mirostat_mode : int = 0 ,
mirostat_tau : float = 5.0 ,
mirostat_eta : float = 0.1 ,
model : Optional [ str ] = None ,
logits_processor : Optional [ llama . LogitsProcessorList ] = None ,
grammar : Optional [ llama . LlamaGrammar ] = None ,
2023-11-08 04:48:51 +01:00
* * kwargs , # type: ignore
2023-11-03 02:12:14 -04:00
) - > Union [ llama_types . ChatCompletion , Iterator [ llama_types . ChatCompletionChunk ] ] :
SYSTEM_MESSAGE = """ A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user ' s questions. The assistant calls functions with appropriate input when necessary """
2023-11-21 04:02:20 -05:00
def generate_type_definition (
param : Dict [ str , llama_types . JsonType ] , indent_level : int , shared_defs
) - > str :
indent = " " * indent_level
if " $ref " in param :
2023-11-10 02:51:58 -05:00
# Reference to a shared definition
2023-11-21 04:02:20 -05:00
ref_name = param [ " $ref " ] . split ( " / " ) [
- 1
] # Extract the type name from the reference
2023-11-10 02:51:58 -05:00
return ref_name
2023-11-21 04:02:20 -05:00
elif param . get ( " type " ) == " array " :
items = param . get ( " items " , { } )
2023-11-10 02:51:58 -05:00
item_type = generate_type_definition ( items , indent_level + 1 , shared_defs )
return f " Array< { item_type } > "
2023-11-21 04:02:20 -05:00
elif param . get ( " type " ) == " object " :
properties = param . get ( " properties " , { } )
2023-11-10 02:51:58 -05:00
nested_schema = " { \n "
for nested_param_name , nested_param in properties . items ( ) :
2023-11-21 04:02:20 -05:00
nested_param_type = generate_type_definition (
nested_param , indent_level + 1 , shared_defs
)
nested_schema + = (
f " { indent } { nested_param_name } : { nested_param_type } , \n "
)
2023-11-10 02:51:58 -05:00
nested_schema + = indent + " } "
return nested_schema
2023-11-21 04:02:20 -05:00
elif " enum " in param :
2023-11-10 02:51:58 -05:00
# Enum type
2023-11-21 04:02:20 -05:00
return " | " . join ( [ f ' " { enum_value } " ' for enum_value in param [ " enum " ] ] )
2023-11-10 02:51:58 -05:00
else :
# Simple type
2023-11-21 04:02:20 -05:00
return param . get ( " type " , " any " )
2023-11-10 02:51:58 -05:00
def generate_shared_definitions ( shared_defs , indent_level : int ) - > str :
2023-11-21 04:02:20 -05:00
indent = " " * indent_level
2023-11-10 02:51:58 -05:00
shared_definitions = " "
for def_name , def_properties in shared_defs . items ( ) :
shared_definitions + = f " { indent } type { def_name } = "
2023-11-21 04:02:20 -05:00
if def_properties . get ( " type " ) == " object " :
shared_definitions + = generate_type_definition (
def_properties , indent_level , shared_defs
)
elif " enum " in def_properties :
2023-11-10 02:51:58 -05:00
# Enum type
2023-11-21 04:02:20 -05:00
shared_definitions + = " | " . join (
[ f ' " { enum_value } " ' for enum_value in def_properties [ " enum " ] ]
)
2023-11-10 02:51:58 -05:00
shared_definitions + = " ; \n "
return shared_definitions
def generate_schema_from_functions ( functions , namespace = " functions " ) - > str :
2023-11-21 04:02:20 -05:00
schema = (
" // Supported function definitions that should be called when necessary. \n "
)
2023-11-03 02:12:14 -04:00
schema + = f " namespace { namespace } {{ \n \n "
2023-11-10 02:51:58 -05:00
# Generate shared definitions
shared_definitions = { }
for function in functions :
parameters = function . get ( " parameters " , { } )
shared_definitions . update ( parameters . get ( " $defs " , { } ) )
schema + = generate_shared_definitions ( shared_definitions , 1 )
2023-11-03 02:12:14 -04:00
for function in functions :
function_name = function [ " name " ]
description = function . get ( " description " , " " )
2023-11-10 02:51:58 -05:00
parameters = function . get ( " parameters " , { } )
2023-11-03 02:12:14 -04:00
required_params = parameters . get ( " required " , [ ] )
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
schema + = f " // { description } \n "
schema + = f " type { function_name } = (_: {{ \n "
2023-11-21 04:02:20 -05:00
2023-11-03 02:12:14 -04:00
for param_name , param in parameters . get ( " properties " , { } ) . items ( ) :
2023-11-10 02:51:58 -05:00
param_description = param . get ( " description " , " " )
param_type = generate_type_definition ( param , 2 , shared_definitions )
optional_indicator = " " if param_name in required_params else " ? "
schema + = f " // { param_description } \n "
schema + = f " { param_name } { optional_indicator } : { param_type } , \n "
schema + = " }) => any; \n \n "
schema + = " }} // namespace {} \n " . format ( namespace )
2023-11-03 02:12:14 -04:00
return schema
def prepare_messages_for_inference(
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
):
    """Render the conversation into this handler's plain-text prompt format.

    The rendered sequence is:
      1. one system message with the function schema, built from ``functions``
         if given, otherwise from the ``"function"``-type entries of ``tools``;
      2. ``SYSTEM_MESSAGE``;
      3. the caller's messages;
      4. an empty assistant turn for the model to complete.

    Function-result message names and assistant ``function_call`` names are
    rendered with a ``functions.`` prefix. The caller's message dicts are
    copied, never modified.

    Returns:
        The concatenated prompt string.

    Raises:
        ValueError: if a message has a role that cannot be rendered.
    """
    all_messages: List[llama_types.ChatCompletionRequestMessage] = []
    # This handler derives ``functions`` from ``tools`` and then passes both, so
    # only one schema message is emitted: prefer ``functions``, else ``tools``.
    if functions is not None:
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=generate_schema_from_functions(functions)
            )
        )
    elif tools is not None:
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system",
                content=generate_schema_from_functions(
                    [tool["function"] for tool in tools if tool["type"] == "function"]
                ),
            )
        )
    all_messages.append(
        llama_types.ChatCompletionRequestSystemMessage(
            role="system", content=SYSTEM_MESSAGE
        )
    )
    for message in messages:
        # Shallow-copy so the caller's dicts are not rewritten; rewriting them
        # would add a second "functions." prefix if the messages are reused.
        message = dict(message)  # type: ignore
        # Function call responses
        if message["role"] == "function" and "name" in message:
            message["name"] = f"functions.{message['name']}"
        # Function call requests by assistant
        if "function_call" in message:
            call = dict(message["function_call"])
            call["name"] = f"functions.{call['name']}"
            message["function_call"] = call  # type: ignore
        all_messages.append(message)  # type: ignore
    all_messages.append(
        llama_types.ChatCompletionRequestAssistantMessage(
            role="assistant", content=None
        )
    )

    def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
        """Render a single message; an empty assistant turn renders as "assistant"."""
        if msg["role"] == "system":
            return f"system:\n{msg['content']}\n"
        elif msg["role"] == "function" and "name" in msg:
            return f"function name={msg['name']}:\n{msg['content']}\n"
        elif msg["role"] == "function" and "function_call" in msg:
            return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
        elif msg["role"] == "tool":
            if msg.get("content") is not None:
                return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
            else:
                return f"function name={msg['tool_call_id']}\n"
        elif msg["role"] == "user":
            if msg.get("content") is None:
                return "user:\n</s></s>\n"
            else:
                return f"user:\n</s>{msg['content']}</s>\n"
        elif msg["role"] == "assistant":
            if msg.get("content") is not None and "function_call" in msg:
                return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
            elif "function_call" in msg:
                return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
            elif msg.get("tool_calls"):
                # Render every tool call, not just the first one.
                # NOTE: probably doesn't work with the functionary model
                return "".join(
                    f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
                    for tool_call in msg["tool_calls"]
                )
            elif msg.get("content") is None:
                # Trailing empty assistant turn: the model continues from here.
                return "assistant"
            else:
                return f"assistant:\n{msg['content']}\n"
        else:
            raise ValueError(f"Unsupported role: {msg['role']}")

    return "".join(message_to_str(msg) for msg in all_messages)
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
if tools is not None :
functions = [ tool [ " function " ] for tool in tools if tool [ " type " ] == " function " ]
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
if tool_choice is not None :
2023-11-21 04:02:20 -05:00
function_call = (
tool_choice if isinstance ( tool_choice , str ) else tool_choice [ " function " ]
)
2023-11-03 02:12:14 -04:00
2023-11-10 02:51:58 -05:00
prompt = prepare_messages_for_inference ( messages , functions , tools )
2023-11-03 02:12:14 -04:00
if function_call is None and ( functions is None or len ( functions ) == 0 ) :
completion_or_completion_chunks = llama . create_completion (
prompt = prompt + " : \n " ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
2023-11-21 06:21:33 +02:00
min_p = min_p ,
typical_p = typical_p ,
2023-11-03 02:12:14 -04:00
stream = stream ,
stop = [ " user: " , " </s> " ] ,
max_tokens = max_tokens ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = grammar ,
)
return _convert_completion_to_chat ( completion_or_completion_chunks , stream = stream ) # type: ignore
if function_call is None or (
isinstance ( function_call , str ) and function_call == " auto "
) :
stop = " \n "
completion : llama_types . Completion = llama . create_completion (
prompt = prompt , stop = stop , stream = False
) # type: ignore
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
# strip " to=functions." and ending ":"
2023-11-10 02:51:58 -05:00
function_call = completion_text . split ( " . " ) [ - 1 ] [ : - 1 ]
2023-11-03 02:12:14 -04:00
new_prompt = prompt + completion_text + stop
elif isinstance ( function_call , str ) and function_call != " none " :
2023-11-10 02:51:58 -05:00
new_prompt = prompt + f " : \n "
2023-11-03 02:12:14 -04:00
elif isinstance ( function_call , dict ) :
2023-11-10 02:51:58 -05:00
new_prompt = prompt + f " to=functions. { function_call [ ' name ' ] } : \n "
2023-11-03 02:12:14 -04:00
function_call = function_call [ " name " ]
else :
2023-11-10 02:51:58 -05:00
new_prompt = prompt + f " : \n "
function_body = None
for function in functions or [ ] :
if function [ " name " ] == function_call :
function_body = function [ " parameters " ]
break
for tool in tools or [ ] :
if tool [ " type " ] == " function " and tool [ " function " ] [ " name " ] == function_call :
function_body = tool [ " function " ] [ " parameters " ]
break
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
if function_body is not None :
try :
with suppress_stdout_stderr ( disable = llama . verbose ) :
2023-11-21 04:02:20 -05:00
grammar_text = llama_grammar . json_schema_to_gbnf (
json . dumps ( function_body )
)
grammar = llama_grammar . LlamaGrammar . from_string (
2024-02-12 15:56:07 -05:00
llama_grammar . json_schema_to_gbnf ( json . dumps ( function_body ) ) ,
verbose = llama . verbose ,
2023-11-21 04:02:20 -05:00
)
2023-11-10 02:51:58 -05:00
print ( grammar_text )
except Exception as e :
if llama . verbose :
2023-11-21 04:02:20 -05:00
print (
" Failed to parse function body as JSON schema, falling back to default grammar "
)
2023-11-10 02:51:58 -05:00
print ( e )
with suppress_stdout_stderr ( disable = llama . verbose ) :
2023-11-21 04:02:20 -05:00
grammar = llama_grammar . LlamaGrammar . from_string (
2024-02-12 15:56:07 -05:00
llama_grammar . JSON_GBNF ,
verbose = llama . verbose ,
2023-11-21 04:02:20 -05:00
)
2023-11-10 02:51:58 -05:00
else :
with suppress_stdout_stderr ( disable = llama . verbose ) :
2024-02-12 15:56:07 -05:00
grammar = llama_grammar . LlamaGrammar . from_string (
llama_grammar . JSON_GBNF , verbose = llama . verbose
)
2023-11-03 02:12:14 -04:00
completion : llama_types . Completion = llama . create_completion (
2023-11-10 02:51:58 -05:00
prompt = new_prompt ,
stop = [ " user: " , " </s> " ] ,
stream = False ,
grammar = grammar ,
max_tokens = max_tokens ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
2023-11-21 06:21:33 +02:00
min_p = min_p ,
typical_p = typical_p ,
2023-11-10 02:51:58 -05:00
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
2023-11-03 02:12:14 -04:00
) # type: ignore
2023-11-08 04:48:51 +01:00
assert " usage " in completion
assert isinstance ( function_call , str )
2023-11-08 00:07:16 -05:00
assert stream is False # TODO: support stream mode
2023-11-08 04:48:51 +01:00
2023-11-23 20:14:23 -05:00
if llama . verbose :
print ( new_prompt )
print ( completion [ " choices " ] [ 0 ] [ " text " ] )
2023-11-09 00:55:23 -05:00
2023-11-23 20:14:23 -05:00
# TODO: support stream mode
2023-11-03 02:12:14 -04:00
return llama_types . CreateChatCompletionResponse (
id = " chat " + completion [ " id " ] ,
object = " chat.completion " ,
created = completion [ " created " ] ,
model = completion [ " model " ] ,
choices = [
{
" index " : 0 ,
" message " : {
2023-11-10 02:51:58 -05:00
" role " : " assistant " ,
2023-11-03 02:12:14 -04:00
" content " : None ,
" function_call " : {
" name " : function_call ,
" arguments " : completion [ " choices " ] [ 0 ] [ " text " ] ,
} ,
2023-11-10 02:51:58 -05:00
" tool_calls " : [
{
" id " : function_call ,
" type " : " function " ,
" function " : {
" name " : function_call ,
" arguments " : completion [ " choices " ] [ 0 ] [ " text " ] ,
2023-11-21 04:02:20 -05:00
} ,
2023-11-10 02:51:58 -05:00
}
2023-11-21 04:02:20 -05:00
] ,
2023-11-03 02:12:14 -04:00
} ,
2024-04-10 03:41:55 -04:00
" logprobs " : completion [ " choices " ] [ 0 ] [ " logprobs " ] ,
2023-11-10 02:51:58 -05:00
" finish_reason " : " tool_calls " ,
2023-11-03 02:12:14 -04:00
}
] ,
usage = completion [ " usage " ] ,
)
2023-11-08 04:48:51 +01:00
2024-02-08 09:07:03 +08:00
@register_chat_completion_handler ( " functionary-v1 " )
@register_chat_completion_handler ( " functionary-v2 " )
def functionary_v1_v2_chat_handler (
llama : llama . Llama ,
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
functions : Optional [ List [ llama_types . ChatCompletionFunction ] ] = None ,
function_call : Optional [ llama_types . ChatCompletionRequestFunctionCall ] = None ,
tools : Optional [ List [ llama_types . ChatCompletionTool ] ] = None ,
tool_choice : Optional [ llama_types . ChatCompletionToolChoiceOption ] = None ,
temperature : float = 0.2 ,
top_p : float = 0.95 ,
top_k : int = 40 ,
min_p : float = 0.05 ,
typical_p : float = 1.0 ,
stream : bool = False ,
stop : Optional [ Union [ str , List [ str ] ] ] = [ ] ,
response_format : Optional [ llama_types . ChatCompletionRequestResponseFormat ] = None ,
max_tokens : Optional [ int ] = None ,
presence_penalty : float = 0.0 ,
frequency_penalty : float = 0.0 ,
repeat_penalty : float = 1.1 ,
tfs_z : float = 1.0 ,
mirostat_mode : int = 0 ,
mirostat_tau : float = 5.0 ,
mirostat_eta : float = 0.1 ,
model : Optional [ str ] = None ,
logits_processor : Optional [ llama . LogitsProcessorList ] = None ,
grammar : Optional [ llama . LlamaGrammar ] = None ,
* * kwargs , # type: ignore
) - > Union [ llama_types . ChatCompletion , Iterator [ llama_types . ChatCompletionChunk ] ] :
SYSTEM_MESSAGE = """ A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user ' s questions. The assistant calls functions with appropriate input when necessary """
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
tokenizer = llama . tokenizer_
2024-02-12 15:56:07 -05:00
assert hasattr (
tokenizer , " hf_tokenizer "
) , " Please provide a valid hf_tokenizer_path from https://huggingface.co/meetkai when initializing the Llama class "
2024-02-08 09:07:03 +08:00
from transformers import AutoTokenizer
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
if " <|START_OF_FUNCTION_CALL|> " in tokenizer . hf_tokenizer . additional_special_tokens :
version = " v1 "
END_SYSTEM_TOKEN = " <|END_OF_SYSTEM|> "
END_USER_TOKEN = " <|END_OF_USER|> "
END_ASSISTANT_TOKEN = " <|END_OF_ASSISTANT|> "
END_FUNCTION_RESULT_TOKEN = " <|END_OF_FUNCTION_RESULT|> "
START_FUNCTION_CALL_TOKEN = " <|START_OF_FUNCTION_CALL|> "
END_FUNCTION_CALL_TOKEN = " <|END_OF_FUNCTION_CALL|> "
else :
version = " v2 "
RECIPIENT_TOKEN = " <|recipient|> "
FROM_TOKEN = " <|from|> "
STOP_TOKEN = " <|stop|> "
CONTENT_TOKEN = " <|content|> "
def generate_type_definition(
    param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
) -> str:
    """Render one JSON-schema node as a TypeScript-like type expression.

    The node is rendered as follows:
      * ``$ref``: the last ``/``-separated segment of the reference.
      * array: ``Array<item type>``.
      * object: a brace block with one ``name: type,`` line per property.
      * enum: a ``|``-separated union of quoted values.
      * anything else: its ``type`` value, or ``"any"`` when absent.

    ``shared_defs`` is only forwarded to the recursive calls.
    """
    pad = "  " * indent_level
    if "$ref" in param:
        # "#/$defs/Foo" -> "Foo"; shared definitions are declared by name.
        return param["$ref"].split("/")[-1]
    kind = param.get("type")
    if kind == "array":
        inner = generate_type_definition(
            param.get("items", {}), indent_level + 1, shared_defs
        )
        return f"Array<{inner}>"
    if kind == "object":
        pieces = ["{\n"]
        for prop_name, prop_schema in param.get("properties", {}).items():
            prop_type = generate_type_definition(
                prop_schema, indent_level + 1, shared_defs
            )
            pieces.append(f"{pad}{prop_name}: {prop_type},\n")
        pieces.append(pad + "}")
        return "".join(pieces)
    if "enum" in param:
        return " | ".join(f'"{value}"' for value in param["enum"])
    # Plain scalar type (or "any" when the schema omits "type").
    return param.get("type", "any")
def generate_shared_definitions(shared_defs, indent_level: int) -> str:
    """Render each ``$defs`` entry as a ``type Name = ...;`` line.

    Object definitions are expanded with ``generate_type_definition``. Enum
    definitions become a ``|``-separated union of quoted values. Every other
    definition (a plain scalar type, an array, a ``$ref``) is also rendered
    with ``generate_type_definition``, so the output never contains an empty
    ``type Name = ;`` line.

    Args:
        shared_defs: mapping of definition name to JSON schema.
        indent_level: indentation depth (two spaces per level) for each line.

    Returns:
        The concatenated definition lines, or ``""`` when ``shared_defs`` is empty.
    """
    indent = "  " * indent_level
    parts = []
    for def_name, def_properties in shared_defs.items():
        if def_properties.get("type") == "object":
            rendered = generate_type_definition(def_properties, indent_level, shared_defs)
        elif "enum" in def_properties:
            # Enum type
            rendered = " | ".join(
                f'"{enum_value}"' for enum_value in def_properties["enum"]
            )
        else:
            # Any other schema: render it generically instead of leaving the
            # right-hand side empty.
            rendered = generate_type_definition(def_properties, indent_level, shared_defs)
        parts.append(f"{indent}type {def_name} = {rendered};\n")
    return "".join(parts)
def generate_schema_from_functions(functions, namespace="functions") -> str:
    """Describe ``functions`` as a TypeScript-like namespace for the system prompt.

    The ``$defs`` of all functions are merged (a later function wins on a name
    clash) and emitted as shared type aliases first. Each function then
    becomes:
      * a ``// description`` comment;
      * ``type name = (_: { ... }) => any;``, with one commented
        ``param?: type,`` line per parameter. Required parameters have no ``?``.
    """
    merged_defs = {}
    for fn in functions:
        merged_defs.update(fn.get("parameters", {}).get("$defs", {}))

    parts = [
        "// Supported function definitions that should be called when necessary.\n",
        f"namespace {namespace} {{\n\n",
        generate_shared_definitions(merged_defs, 1),
    ]
    for fn in functions:
        params = fn.get("parameters", {})
        required = params.get("required", [])
        parts.append(f"// {fn.get('description', '')}\n")
        parts.append(f"type {fn['name']} = (_: {{\n")
        for arg_name, arg_schema in params.get("properties", {}).items():
            marker = "" if arg_name in required else "?"
            arg_type = generate_type_definition(arg_schema, 2, merged_defs)
            parts.append(f"// {arg_schema.get('description', '')}\n")
            parts.append(f"{arg_name}{marker}: {arg_type},\n")
        parts.append("}) => any;\n\n")
    parts.append(f"}} // namespace {namespace}")
    return "".join(parts)
def prepare_messages_for_inference(
    messages: List[llama_types.ChatCompletionRequestMessage],
    tokenizer: AutoTokenizer,
    version: Literal["v1", "v2"],
    functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Union[Dict, str] = "auto",
):
    """Build the functionary prompt with the Hugging Face chat template.

    The templated conversation is:
      1. a system message with the function schema:
         * empty when ``tool_choice == "none"``;
         * otherwise from ``functions``, or, if that is None, from the
           ``"function"``-type entries of ``tools``;
         * omitted when neither is given;
      2. ``SYSTEM_MESSAGE``;
      3. the caller's messages.

    Function-result message names and assistant ``function_call`` names get a
    ``functions.`` prefix. The caller's message dicts are copied, not modified.

    Returns:
        The untokenized chat-template text followed by the version-specific
        assistant generation prefix.
    """
    all_messages: List[llama_types.ChatCompletionRequestMessage] = []
    if tool_choice == "none":
        # Advertise an empty function namespace.
        schema_functions = []
    elif functions is not None:
        schema_functions = functions
    elif tools is not None:
        schema_functions = [
            tool["function"] for tool in tools if tool["type"] == "function"
        ]
    else:
        schema_functions = None
    if schema_functions is not None:
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=generate_schema_from_functions(schema_functions)
            )
        )
    all_messages.append(
        llama_types.ChatCompletionRequestSystemMessage(
            role="system", content=SYSTEM_MESSAGE
        )
    )
    for message in messages:
        # Shallow-copy so the caller's dicts are not rewritten; rewriting them
        # would add a second "functions." prefix if the messages are reused.
        message = dict(message)  # type: ignore
        # Function call responses
        if message["role"] == "function" and "name" in message:
            message["name"] = f"functions.{message['name']}"
        # Function call requests by assistant
        if "function_call" in message:
            call = dict(message["function_call"])
            call["name"] = f"functions.{call['name']}"
            message["function_call"] = call  # type: ignore
        all_messages.append(message)  # type: ignore

    if version == "v1":
        suffix = "assistant:\n"
    else:
        suffix = "<|from|>assistant\n<|recipient|>"

    return (
        tokenizer.hf_tokenizer.apply_chat_template(all_messages, tokenize=False)
        + suffix
    )
2024-02-08 09:07:03 +08:00
if tools is not None :
functions = [ tool [ " function " ] for tool in tools if tool [ " type " ] == " function " ]
if tool_choice is not None :
function_call = (
tool_choice if isinstance ( tool_choice , str ) else tool_choice [ " function " ]
)
2024-05-04 22:11:20 +08:00
elif function_call is not None :
pass
2024-03-18 22:40:57 +08:00
else :
function_call = " auto "
2024-02-08 09:07:03 +08:00
2024-02-12 15:56:07 -05:00
prompt = prepare_messages_for_inference (
2024-04-28 08:49:52 +08:00
messages , tokenizer , version , functions , tools , function_call
2024-02-12 15:56:07 -05:00
)
2024-02-08 09:07:03 +08:00
# If no tools/functions are provided
2024-03-18 22:40:57 +08:00
if function_call == " none " or functions is None or len ( functions ) == 0 :
2024-02-08 09:07:03 +08:00
if version == " v1 " :
stop = END_ASSISTANT_TOKEN
else :
stop = STOP_TOKEN
prompt + = " all \n <|content|> "
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
completion_or_completion_chunks = llama . create_completion (
prompt = prompt ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = stream ,
stop = stop ,
max_tokens = max_tokens ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = grammar ,
)
2024-05-04 22:11:20 +08:00
if stream is False :
completion_or_completion_chunks [ " choices " ] [ 0 ] [ " text " ] = completion_or_completion_chunks [ " choices " ] [ 0 ] [ " text " ] . lstrip ( )
2024-02-08 09:07:03 +08:00
return _convert_completion_to_chat ( completion_or_completion_chunks , stream = stream ) # type: ignore
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
def get_grammar(function_call):
    """Build a grammar that constrains generation to a function's JSON arguments.

    The parameter schema for ``function_call`` is looked up in ``functions``
    and then in ``tools``; a match in ``tools`` overrides one in
    ``functions``. Returns the generic JSON grammar in two cases:
      * the function name is unknown;
      * its schema cannot be converted to GBNF.
    The generated grammar text and any conversion error are printed only when
    ``llama.verbose`` is set.
    """
    function_body = None
    for function in functions or []:
        if function["name"] == function_call:
            function_body = function["parameters"]
            break
    for tool in tools or []:
        if tool["type"] == "function" and tool["function"]["name"] == function_call:
            function_body = tool["function"]["parameters"]
            break

    if function_body is not None:
        try:
            with suppress_stdout_stderr(disable=llama.verbose):
                grammar_text = llama_grammar.json_schema_to_gbnf(
                    json.dumps(function_body)
                )
                grammar = llama_grammar.LlamaGrammar.from_string(
                    grammar_text, verbose=llama.verbose
                )
            # Previously printed unconditionally; keep it as verbose-only output.
            if llama.verbose:
                print(grammar_text)
            return grammar
        except Exception as e:
            if llama.verbose:
                print(
                    "Failed to parse function body as JSON schema, falling back to default grammar"
                )
                print(e)

    # Unknown function or unconvertible schema: constrain to generic JSON.
    with suppress_stdout_stderr(disable=llama.verbose):
        return llama_grammar.LlamaGrammar.from_string(
            llama_grammar.JSON_GBNF, verbose=llama.verbose
        )
2024-02-12 15:56:07 -05:00
2024-05-04 22:11:20 +08:00
def create_completion(prompt, stop, grammar):
    """Run ``llama.create_completion`` with this handler's sampling settings.

    Only ``prompt``, ``stop`` and ``grammar`` vary between calls. Everything
    else, including ``stream``, comes from the enclosing handler's arguments.
    With ``stream`` set, ``llama.create_completion`` is called in streaming
    mode. The ``cast`` only informs type checkers.
    """
    sampling_kwargs = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        stream=stream,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
    )
    result = llama.create_completion(
        prompt=prompt, stop=stop, grammar=grammar, **sampling_kwargs
    )
    return cast(llama_types.Completion, result)
2024-02-12 15:56:07 -05:00
2024-03-18 22:40:57 +08:00
content = " "
2024-02-08 09:07:03 +08:00
function_calls , function_bodies = [ ] , [ ]
2024-04-28 08:49:52 +08:00
completion_tokens = 0
2024-05-04 22:11:20 +08:00
def generate_streaming(tools, functions, function_call, prompt):
    """Yield the assistant reply as OpenAI-style chat completion chunks (v2 format only).

    There are two modes, selected by ``function_call``:

    * dict: the named function is forced. Behaviour:
      - its arguments are generated under that function's grammar;
      - they are streamed as ``tool_calls`` deltas when ``tools`` is not None,
        otherwise as ``function_call`` deltas;
      - a final chunk carries ``finish_reason`` "tool_calls" or
        "function_call".
    * ``"auto"``: each turn, the model first generates a recipient name.
      - ``"all"`` means text for the user, streamed as ``content`` deltas.
      - Any other name is treated as a function call, whose name and then
        arguments are streamed.
      - The loop continues while the generated text opens another assistant
        turn.

    Any other ``function_call`` value yields nothing.

    Args:
        tools: the request's tools list or None. Whether it is None selects
            between the ``tool_calls`` and ``function_call`` delta shapes.
        functions: not read directly in this body. ``get_grammar`` uses the
            enclosing handler's ``functions``/``tools`` instead.
        function_call: a dict with a ``"name"`` key, or the string ``"auto"``.
        prompt: prompt text ending at the v2 recipient slot. It is extended
            locally as turns are generated.
    """
    assert version == "v2", "Streaming for v1 is not supported"
    # In "auto" mode these are captured from the first generated chunk and
    # reused for every chunk yielded there.
    chunk_id, chunk_created = None, None

    # If tool_choice/function_call is provided
    if isinstance(function_call, dict):
        # Fill the recipient slot with the forced function name.
        prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
        grammar = get_grammar(function_call["name"])
        stops = [STOP_TOKEN, FROM_TOKEN]
        # Random 24-char id, exposed as "call_<id>" in tool_calls deltas.
        tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
        completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
        # NOTE(review): completion_text is never updated in this branch.
        completion_text = ""
        first = True
        for chunk in completion:
            # Yield the tool/function name first
            if first:
                if tools is not None:
                    func_call_dict = {
                        "tool_calls": [
                            {
                                "index": 0,
                                "id": "call_" + tool_id,
                                "type": "function",
                                "function": {"name": function_call["name"], "arguments": ""},
                            }
                        ]
                    }
                else:
                    func_call_dict = {"function_call": {"name": function_call["name"], "arguments": ""}}
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk["id"],
                    object="chat.completion.chunk",
                    created=chunk["created"],
                    model=chunk["model"],
                    choices=[
                        {"index": 0, "logprobs": None, "delta": {"role": None, "content": None, **func_call_dict}}
                    ],
                )
                first = False
            # Every chunk (including the first) is streamed as an arguments
            # delta, right-stripped; whitespace-only chunks are skipped below.
            if tools is not None:
                func_call_dict = {
                    "tool_calls": [
                        {
                            "index": 0,
                            "id": "call_" + tool_id,
                            "type": "function",
                            "function": {
                                "name": None,
                                "arguments": chunk["choices"][0]["text"].rstrip(),
                            },
                        }
                    ]
                }
            else:
                func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}}
            if len(chunk["choices"][0]["text"].rstrip()) > 0:
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk["id"],
                    object="chat.completion.chunk",
                    created=chunk["created"],
                    model=chunk["model"],
                    choices=[
                        {
                            "index": 0,
                            "logprobs": chunk["choices"][0]["logprobs"],
                            "delta": {
                                "role": None,
                                "content": None,
                                **func_call_dict,
                            },
                        }
                    ],
                )
        # Yield tool_call/function_call stop message
        # NOTE(review): reads the last `chunk` from the loop above; this raises
        # NameError if the completion produced no chunks.
        yield llama_types.CreateChatCompletionStreamResponse(
            id="chat" + chunk["id"],
            object="chat.completion.chunk",
            created=chunk["created"],
            model=chunk["model"],
            choices=[
                {
                    "index": 0,
                    "finish_reason": "tool_calls" if tools is not None else "function_call",
                    "logprobs": None,
                    "delta": {
                        "role": None, "content": None, "function_call": None, "tool_calls": None
                    },
                }
            ],
        )
    # If "auto" or no tool_choice/function_call
    elif isinstance(function_call, str) and function_call == "auto":
        # Position of the current call in the streamed tool_calls list.
        tool_index = 0
        while True:
            # Generate function name first
            grammar = None
            stops = CONTENT_TOKEN
            completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
            completion_text = ""
            for chunk in completion:
                completion_text += chunk["choices"][0]["text"]
                if chunk_id is None:
                    chunk_id = chunk["id"]
                if chunk_created is None:
                    chunk_created = chunk["created"]
            # NOTE(review): if the first name generation yields no chunks,
            # chunk_id stays None and `chunk` is unbound in the yields below.
            function_name = completion_text.strip()
            if function_name == "all":
                # "all" = plain text addressed to the user.
                prompt += "all\n<|content|>"
                # Yield the first empty message for content
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk_id,
                    model=chunk["model"],
                    created=chunk_created,
                    object="chat.completion.chunk",
                    choices=[
                        {
                            "index": 0,
                            "delta": {"role": "assistant", "content": ""},
                            "logprobs": None,
                            "finish_reason": None,
                        }
                    ],
                )
            else:
                # Any other recipient is a function call; its arguments are
                # generated under that function's grammar.
                prompt += f"{function_name}\n<|content|>"
                grammar = get_grammar(function_name)
                tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
                if tools is not None:
                    func_call_dict = {
                        "tool_calls": [
                            {
                                "index": tool_index,
                                "id": "call_" + tool_id,
                                "type": "function",
                                "function": {"name": function_name, "arguments": ""},
                            }
                        ]
                    }
                else:
                    func_call_dict = {"function_call": {"name": function_name, "arguments": ""}}
                # Stream function name
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk_id,
                    object="chat.completion.chunk",
                    created=chunk_created,
                    model=chunk["model"],
                    choices=[
                        {
                            "index": 0,
                            "logprobs": chunk["choices"][0]["logprobs"],
                            "delta": {
                                "role": "assistant",
                                "content": None,
                                **func_call_dict,
                            },
                        }
                    ],
                )
            # Generate content
            stops = [RECIPIENT_TOKEN, STOP_TOKEN]
            completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
            if function_name == "all":
                completion_text = ""
                # A chunk that is exactly "\n" may start the marker for another
                # assistant turn (stop_sequence). While the buffered pieces
                # remain a prefix of that marker they are held back. On a
                # mismatch:
                #   * the current piece is dropped from the buffer;
                #   * the held pieces are flushed as content;
                #   * the current chunk is then emitted by the check below.
                # A fully matching marker is never emitted.
                # NOTE(review): buffered pieces have spaces stripped
                # (.strip(" ")), so flushed text may lose spaces.
                stop_sequence, buffer, is_end = "\n<|from|>assistant\n<|recipient|>", [], False
                for i, chunk in enumerate(completion):
                    completion_text += chunk["choices"][0]["text"]
                    if is_end:
                        buffer.append(chunk["choices"][0]["text"].strip(" "))
                        if stop_sequence.startswith("".join(buffer)):
                            continue
                        else:
                            buffer.pop()
                            while len(buffer) > 0:
                                yield llama_types.CreateChatCompletionStreamResponse(
                                    id="chat" + chunk_id,
                                    object="chat.completion.chunk",
                                    created=chunk_created,
                                    model=chunk["model"],
                                    choices=[
                                        {
                                            "index": 0,
                                            "logprobs": chunk["choices"][0]["logprobs"],
                                            "delta": {
                                                "role": "assistant", "content": buffer.pop(0)
                                            },
                                        }
                                    ],
                                )
                            is_end = False
                    elif chunk["choices"][0]["text"] == "\n":
                        is_end = True
                        buffer.append(chunk["choices"][0]["text"].strip(" "))
                        continue
                    if len(buffer) == 0 and len(chunk["choices"][0]["text"]) > 0:
                        yield llama_types.CreateChatCompletionStreamResponse(
                            id="chat" + chunk_id,
                            object="chat.completion.chunk",
                            created=chunk_created,
                            model=chunk["model"],
                            choices=[
                                {
                                    "index": 0,
                                    "logprobs": chunk["choices"][0]["logprobs"],
                                    "delta": {
                                        "role": "assistant",
                                        # Leading whitespace is trimmed from the first chunk only.
                                        "content": chunk["choices"][0]["text"] if i > 0 else chunk["choices"][0]["text"].lstrip()
                                    },
                                }
                            ],
                        )
                # Check whether the model wants to generate another turn
                if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
                    # Drop the trailing turn marker (with or without the space)
                    # before re-appending the canonical marker to the prompt.
                    if completion_text.endswith("\n<|from|>assistant\n"):
                        cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
                    elif completion_text.endswith("\n<|from|> assistant\n"):
                        cleaned_completion_text = completion_text[:-len("\n<|from|> assistant\n")].strip()
                    else:
                        cleaned_completion_text = completion_text.strip()
                    prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
                else:
                    # Yield stop message
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk_id,
                        model=chunk["model"],
                        created=chunk_created,
                        object="chat.completion.chunk",
                        choices=[
                            {
                                "index": 0,
                                "delta": {},
                                "logprobs": None,
                                "finish_reason": "stop",
                            }
                        ],
                    )
                    break
            else:
                # Check whether the model wants to generate another turn
                # First stream the function arguments (right-stripped,
                # whitespace-only chunks skipped).
                completion_text = ""
                for chunk in completion:
                    completion_text += chunk["choices"][0]["text"]
                    if len(chunk["choices"][0]["text"].rstrip()) > 0:
                        if tools is not None:
                            func_call_dict = {
                                "tool_calls": [
                                    {
                                        "index": tool_index,
                                        "id": "call_" + tool_id,
                                        "type": "function",
                                        "function": {
                                            "name": None,
                                            "arguments": chunk["choices"][0]["text"].rstrip(),
                                        },
                                    }
                                ]
                            }
                        else:
                            func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}}
                        yield llama_types.CreateChatCompletionStreamResponse(
                            id="chat" + chunk_id,
                            object="chat.completion.chunk",
                            created=chunk_created,
                            model=chunk["model"],
                            choices=[
                                {
                                    "index": 0,
                                    "logprobs": chunk["choices"][0]["logprobs"],
                                    "delta": {
                                        "role": None,
                                        "content": None,
                                        **func_call_dict,
                                    },
                                }
                            ],
                        )
                # Generate (without a grammar) past the arguments to see whether
                # the model opens another assistant turn. This continuation is
                # consumed here and not streamed.
                prompt += completion_text.strip()
                grammar = None
                completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
                completion_text += "".join([chunk["choices"][0]["text"] for chunk in completion])
                # Only the tools API continues with further calls; the legacy
                # function_call API stops after one call.
                if ("<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text) and tools is not None:
                    prompt += "\n<|from|>assistant\n<|recipient|>"
                    tool_index += 1
                else:
                    # Yield tool_call/function_call stop message
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk_id,
                        object="chat.completion.chunk",
                        created=chunk_created,
                        model=chunk["model"],
                        choices=[
                            {
                                "index": 0,
                                "finish_reason": "tool_calls" if tools is not None else "function_call",
                                "logprobs": None,
                                "delta": {
                                    "role": None, "content": None, "function_call": None, "tool_calls": None
                                },
                            }
                        ],
                    )
                    break
2024-05-04 22:11:20 +08:00
if stream is not False :
return generate_streaming (
tools = tools , functions = functions , function_call = function_call , prompt = prompt
2024-02-08 09:07:03 +08:00
)
2024-05-04 22:11:20 +08:00
else :
if version == " v1 " :
# If no or "auto" tool_choice/function_call
if isinstance ( function_call , str ) and function_call == " auto " :
stops = [ " \n " , END_ASSISTANT_TOKEN ]
# If tool_choice/function_call is provided
elif isinstance ( function_call , dict ) :
prompt + = f " { START_FUNCTION_CALL_TOKEN } { function_call [ ' name ' ] } : \n "
stops = END_FUNCTION_CALL_TOKEN
function_call = function_call [ " name " ]
function_calls . append ( function_call )
grammar = get_grammar ( function_call )
else :
prompt = prompt
stops = [ " \n " , END_ASSISTANT_TOKEN ]
2024-02-08 09:07:03 +08:00
2024-05-08 07:21:27 +01:00
completion = create_completion ( prompt = prompt , stop = stops , grammar = grammar )
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
# If the generation does not involve a function call
if (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN not in completion_text
) :
completion [ " usage " ] [ " completion_tokens " ] = completion_tokens
return _convert_completion_to_chat ( completion , stream = stream ) # type: ignore
# If the generation involves a function call in completion, generate the parameters
elif (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN in completion_text
) :
prompt + = (
completion_text . replace (
f " { START_FUNCTION_CALL_TOKEN } " , START_FUNCTION_CALL_TOKEN
)
+ " \n "
)
function_calls . append (
completion_text . split ( START_FUNCTION_CALL_TOKEN ) [ - 1 ] [ : - 1 ] . strip ( )
)
grammar = get_grammar ( function_calls [ - 1 ] )
2024-05-08 07:21:27 +01:00
completion = create_completion ( prompt = prompt , stop = END_FUNCTION_CALL_TOKEN , grammar = grammar )
2024-05-04 22:11:20 +08:00
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
function_bodies . append ( completion [ " choices " ] [ 0 ] [ " text " ] . strip ( ) )
# If the prompt involves a function call, just append generated parameters to function_bodies
else :
function_bodies . append ( completion_text . strip ( ) )
2024-04-28 08:49:52 +08:00
else :
2024-05-04 22:11:20 +08:00
# If tool_choice/function_call is provided
if isinstance ( function_call , dict ) :
prompt + = f " { function_call [ ' name ' ] } \n { CONTENT_TOKEN } "
function_call = function_call [ " name " ]
function_calls . append ( function_call )
grammar = get_grammar ( function_call )
stops = [ STOP_TOKEN , FROM_TOKEN ]
2024-05-08 07:21:27 +01:00
completion = create_completion ( prompt = prompt , stop = stops , grammar = grammar )
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
function_bodies . append ( completion_text . strip ( ) )
# If "auto" or no tool_choice/function_call
elif isinstance ( function_call , str ) and function_call == " auto " :
while True :
# Generate function name first
grammar = None
stops = CONTENT_TOKEN
2024-05-08 07:21:27 +01:00
completion = create_completion ( prompt = prompt , stop = stops , grammar = grammar )
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
function_name = completion_text . strip ( )
if function_name == " all " :
prompt + = " all \n <|content|> "
else :
function_call = completion_text . strip ( )
prompt + = f " { function_call } \n <|content|> "
function_calls . append ( function_call )
grammar = get_grammar ( function_call )
# Generate content
stops = [ RECIPIENT_TOKEN , STOP_TOKEN ]
2024-05-08 07:21:27 +01:00
completion = create_completion ( prompt = prompt , stop = stops , grammar = grammar )
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
if function_name == " all " :
if completion_text . endswith ( " \n <|from|>assistant \n " ) :
content + = completion_text [ : - len ( " \n <|from|>assistant \n " ) ]
if completion_text . endswith ( " \n <|from|> assistant \n " ) :
content + = completion_text [ - len ( " \n <|from|> assistant \n " ) ]
else :
content + = completion_text
content = content . lstrip ( )
# Check whether the model wants to generate another turn
if " <|from|> assistant " in completion_text or " <|from|>assistant " in completion_text :
if completion_text . endswith ( " \n <|from|>assistant \n " ) :
cleaned_completion_text = completion_text [ : - len ( " \n <|from|>assistant \n " ) ] . strip ( )
elif completion_text . endswith ( " \n <|from|> assistant \n " ) :
cleaned_completion_text = completion_text [ - len ( " \n <|from|> assistant \n " ) ] . strip ( )
else :
cleaned_completion_text = completion_text . strip ( )
prompt + = f " { cleaned_completion_text } \n <|from|>assistant \n <|recipient|> "
else :
break
else :
function_bodies . append ( completion_text . strip ( ) )
# Check whether the model wants to generate another turn
prompt + = completion_text . strip ( )
grammar = None
2024-05-08 07:21:27 +01:00
completion = create_completion ( prompt = prompt , stop = stops , grammar = grammar )
2024-05-04 22:11:20 +08:00
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
if " <|from|> assistant " in completion [ " choices " ] [ 0 ] [ " text " ] or " <|from|>assistant " in completion [ " choices " ] [ 0 ] [ " text " ] :
prompt + = " \n <|from|>assistant \n <|recipient|> "
else :
break
assert " usage " in completion
assert len ( function_calls ) == len ( function_bodies )
tool_calls : List [ llama_types . ChatCompletionMessageToolCall ] = [ ]
for function_call , function_body in zip ( function_calls , function_bodies ) :
tool_calls . append (
{
" id " : " call_ "
+ " " . join (
[
random . choice ( string . ascii_letters + string . digits )
for _ in range ( 24 )
]
) ,
" type " : " function " ,
" function " : {
" name " : function_call ,
" arguments " : function_body ,
} ,
}
)
# TODO: support stream mode
function_call_dict : Union [ Dict [ str , str ] , Dict [ Literal [ " function_call " ] , llama_types . ChatCompletionRequestAssistantMessageFunctionCall ] ] = { }
if len ( tool_calls ) > 0 :
if tools is not None :
function_call_dict [ " tool_calls " ] = tool_calls
else :
function_call_dict [ " function_call " ] = {
" name " : tool_calls [ 0 ] [ " function " ] [ " name " ] ,
" arguments " : tool_calls [ 0 ] [ " function " ] [ " arguments " ] ,
}
completion [ " usage " ] [ " completion_tokens " ] = completion_tokens
return llama_types . CreateChatCompletionResponse (
id = " chat " + completion [ " id " ] ,
object = " chat.completion " ,
created = completion [ " created " ] ,
model = completion [ " model " ] ,
choices = [
{
" index " : 0 ,
" logprobs " : completion [ " choices " ] [ 0 ] [ " logprobs " ] ,
" message " : {
" role " : " assistant " ,
" content " : None if content == " " else content ,
* * function_call_dict ,
} ,
" finish_reason " : " tool_calls " if len ( tool_calls ) > 0 else " stop " ,
}
] ,
usage = completion [ " usage " ] ,
)
2024-02-08 09:07:03 +08:00
2023-11-08 04:48:51 +01:00
class Llava15ChatHandler:
    """Chat completion handler for LLaVA-1.5 style vision-language models.

    The request ``messages`` are rendered with the Jinja2 template in
    ``CHAT_FORMAT``. Image URLs are written verbatim into the rendered text and
    located again by ``split_text_on_image_urls``:

    - text chunks are tokenized and evaluated with ``llama.eval``;
    - image chunks are loaded, embedded with the CLIP projector from
      ``clip_model_path``, and evaluated with ``llava_eval_image_embed``.

    Generation then continues from the evaluated prompt via
    ``llama.create_completion``.

    Subclasses adapt the prompt by overriding ``CHAT_FORMAT`` and
    ``DEFAULT_SYSTEM_MESSAGE``. ``close()`` (also invoked from ``__del__``)
    releases the native CLIP context and the cached image embedding.
    """

    # Injected as a system message when the request has none (i.e. when
    # `_get_system_message` returns ""). Set to None to disable injection.
    DEFAULT_SYSTEM_MESSAGE: Optional[str] = "A chat between a curious human and an artificial intelligence assistant.  The assistant gives helpful, detailed, and polite answers to the human's questions."

    # Jinja2 prompt template. Image URLs are rendered as plain text here and
    # later swapped for image embeddings by `split_text_on_image_urls`.
    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is string %}"
        "\nUSER: {{ message.content }}"
        # `elif`, not a separate `if`: Jinja's `iterable` test is also true for
        # strings, so a second `if` emitted an extra, empty "\nUSER: " after
        # every plain-string user message.
        "{% elif message.content is iterable %}"
        "\nUSER: "
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "{% endif %}"
        "{% if message.role == 'assistant' and message.content is not none %}"
        "\nASSISTANT: {{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "\nASSISTANT: "
        "{% endif %}"
    )

    def __init__(self, clip_model_path: str, verbose: bool = True):
        """Load the CLIP projector and register native cleanup callbacks.

        Args:
            clip_model_path: Filesystem path of the CLIP / multimodal projector.
            verbose: When False, native stdout/stderr output is suppressed.

        Raises:
            ValueError: If the path does not exist or the model fails to load.
        """
        import llama_cpp.llava_cpp as llava_cpp

        self.clip_model_path = clip_model_path
        self.verbose = verbose

        self._llava_cpp = llava_cpp  # TODO: Fix
        # Created before anything that can fail so `close()` always has a stack.
        self._exit_stack = ExitStack()
        self._last_image_embed: Optional[llava_cpp.CtypesPointer[llava_cpp.llava_image_embed]] = None
        self._last_image_hash: Optional[int] = None

        if not os.path.exists(clip_model_path):
            raise ValueError(f"Clip model path does not exist: {clip_model_path}")

        with suppress_stdout_stderr(disable=self.verbose):
            clip_ctx = self._llava_cpp.clip_model_load(
                self.clip_model_path.encode(), 0
            )

        if clip_ctx is None:
            raise ValueError(f"Failed to load clip model: {clip_model_path}")
        self.clip_ctx = clip_ctx

        def clip_free():
            with suppress_stdout_stderr(disable=self.verbose):
                self._llava_cpp.clip_free(self.clip_ctx)
            # Make any use after close() fail loudly instead of touching freed memory.
            self.clip_ctx = None

        self._exit_stack.callback(clip_free)

        def last_image_embed_free():
            with suppress_stdout_stderr(disable=self.verbose):
                if self._last_image_embed is not None:
                    self._llava_cpp.llava_image_embed_free(self._last_image_embed)
                    self._last_image_embed = None
                    self._last_image_hash = None

        # ExitStack runs callbacks LIFO: the embedding is freed before the CLIP context.
        self._exit_stack.callback(last_image_embed_free)

    def close(self) -> None:
        """Free the native CLIP context and any cached image embedding.

        Calling this more than once is harmless: `ExitStack.close` runs each
        registered callback at most once.
        """
        exit_stack = getattr(self, "_exit_stack", None)
        if exit_stack is not None:
            exit_stack.close()

    def __del__(self) -> None:
        self.close()

    def load_image(self, image_url: str) -> bytes:
        """Return the raw bytes for `image_url` (hook for subclasses)."""
        return self._load_image(image_url)

    def __call__(
        self,
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        """Evaluate a multimodal chat prompt and generate a completion.

        The prompt (text and images) is evaluated into `llama`'s context
        first. Generation is then requested with the already-evaluated token
        ids as the prompt.

        Returns:
            A chat completion response, or an iterator of stream chunks when
            `stream` is true.

        Raises:
            ValueError: If the handler was closed, the prompt exceeds
                `llama.n_ctx()`, an image cannot be embedded, or a requested
                tool is not present in `tools`.
        """
        if self.clip_ctx is None:
            raise ValueError("CLIP model has been freed; this chat handler was closed")

        system_prompt = _get_system_message(messages)
        if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None:
            messages = [
                llama_types.ChatCompletionRequestSystemMessage(
                    role="system", content=self.DEFAULT_SYSTEM_MESSAGE
                )
            ] + messages

        image_urls = self.get_image_urls(messages)
        template = jinja2.Template(self.CHAT_FORMAT)
        text = template.render(messages=messages, add_generation_prompt=True)
        split_text = self.split_text_on_image_urls(text, image_urls)

        def embed_image_bytes(image_bytes: bytes):
            # Reuse the cached embedding when the same image is sent again.
            image_hash = hash(image_bytes)
            if (
                self._last_image_embed is not None
                and self._last_image_hash is not None
                and image_hash == self._last_image_hash
            ):
                return self._last_image_embed
            with suppress_stdout_stderr(disable=self.verbose):
                # Release the previous embedding before caching a new one;
                # overwriting the pointer leaked one native buffer per image.
                if self._last_image_embed is not None:
                    self._llava_cpp.llava_image_embed_free(self._last_image_embed)
                    self._last_image_embed = None
                    self._last_image_hash = None
                embed = self._llava_cpp.llava_image_embed_make_with_bytes(
                    self.clip_ctx,
                    llama.context_params.n_threads_batch,
                    (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
                    len(image_bytes),
                )
            if not embed:
                # A NULL pointer would otherwise fail later on `embed.contents`.
                raise ValueError("Failed to create image embedding")
            self._last_image_embed = embed
            self._last_image_hash = image_hash
            return embed

        # Evaluate prompt
        llama.reset()
        for i, (type_, value) in enumerate(split_text):
            if type_ == "text":
                tokens = llama.tokenize(value.encode("utf8"), add_bos=i == 0)
                if llama.n_tokens + len(tokens) > llama.n_ctx():
                    raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                llama.eval(tokens)
            else:
                image_bytes = self.load_image(value)
                embed = embed_image_bytes(image_bytes)
                if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
                    raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                # llava_eval_image_embed advances n_past in place via the pointer.
                n_past = ctypes.c_int(llama.n_tokens)
                n_past_p = ctypes.pointer(n_past)
                with suppress_stdout_stderr(disable=self.verbose):
                    self._llava_cpp.llava_eval_image_embed(
                        llama.ctx,
                        embed,
                        llama.n_batch,
                        n_past_p,
                    )
                llama.n_tokens = n_past.value

        # Get prompt tokens to avoid a cache miss
        prompt = llama.input_ids[: llama.n_tokens].tolist()

        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(response_format)

        # Convert legacy functions to tools
        if functions is not None:
            tools = [
                {
                    "type": "function",
                    "function": function,
                }
                for function in functions
            ]

        # Convert legacy function_call to tool_choice
        if function_call is not None:
            if isinstance(function_call, str) and (
                function_call == "none" or function_call == "auto"
            ):
                tool_choice = function_call
            if isinstance(function_call, dict) and "name" in function_call:
                tool_choice = {
                    "type": "function",
                    "function": {
                        "name": function_call["name"],
                    },
                }

        tool = None
        if isinstance(tool_choice, dict) and tools is not None:
            name = tool_choice["function"]["name"]
            tool = next((t for t in tools if t["function"]["name"] == name), None)
            if tool is None:
                raise ValueError(f"Tool choice '{name}' not found in tools.")
            schema = tool["function"]["parameters"]
            try:
                # create grammar from json schema
                grammar = llama_grammar.LlamaGrammar.from_json_schema(
                    json.dumps(schema), verbose=llama.verbose
                )
            except Exception:
                # Best effort: fall back to unconstrained JSON when the schema
                # cannot be converted to a grammar.
                grammar = llama_grammar.LlamaGrammar.from_string(
                    llama_grammar.JSON_GBNF, verbose=llama.verbose
                )

        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            logprobs=top_logprobs if logprobs else None,
            stream=stream,
            stop=stop,
            seed=seed,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
            logit_bias=logit_bias,
        )
        if tool is not None:
            tool_name = tool["function"]["name"]
            return _convert_completion_to_chat_function(
                tool_name, completion_or_chunks, stream
            )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

    @staticmethod
    def _load_image(image_url: str) -> bytes:
        """Return the bytes of a base64 `data:` URL or of a URL fetched with urllib."""
        # TODO: Add Pillow support for other image formats beyond (jpg, png)
        if image_url.startswith("data:"):
            import base64

            image_bytes = base64.b64decode(image_url.split(",")[1])
            return image_bytes
        else:
            import urllib.request

            with urllib.request.urlopen(image_url) as f:
                image_bytes = f.read()
                return image_bytes

    @staticmethod
    def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
        """Collect image URLs from user messages in message/content order.

        `image_url` parts may hold a plain string or a mapping with a "url"
        key. Plain-string or `None` contents carry no images and are skipped.
        """
        image_urls: List[str] = []
        for message in messages:
            if message["role"] != "user":
                continue
            content = message["content"]
            if content is None or isinstance(content, str):
                continue
            for part in content:
                if not (isinstance(part, dict) and part.get("type") == "image_url"):
                    continue
                image_url = part["image_url"]
                if isinstance(image_url, dict) and "url" in image_url:
                    image_urls.append(image_url["url"])
                else:
                    image_urls.append(image_url)
        return image_urls

    @staticmethod
    def split_text_on_image_urls(text: str, image_urls: List[str]):
        """Split rendered prompt text into ("text", ...) and ("image_url", ...) chunks.

        At each step the URL occurring *earliest* in the remaining text is
        split out. On a tie, the longest URL wins, so a URL that is a prefix of
        another cannot cut it short. Empty URLs are ignored: they would match
        at position 0 forever without consuming any text.
        """
        candidates = [url for url in image_urls if url]

        def find_earliest(s: str) -> Tuple[int, Optional[str]]:
            best_pos, best_url = -1, None
            for url in candidates:
                pos = s.find(url)
                if pos == -1:
                    continue
                if (
                    best_url is None
                    or pos < best_pos
                    or (pos == best_pos and len(url) > len(best_url))
                ):
                    best_pos, best_url = pos, url
            return best_pos, best_url

        split_text: List[Tuple[Literal["text", "image_url"], str]] = []
        remaining = text
        while remaining:
            pos, url = find_earliest(remaining)
            if url is None:
                split_text.append(("text", remaining))
                break
            if pos > 0:
                split_text.append(("text", remaining[:pos]))
            split_text.append(("image_url", url))
            remaining = remaining[pos + len(url):]
        return split_text

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        filename: Optional[str],
        local_dir: Optional[Union[str, os.PathLike[str]]] = None,
        local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
        cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
        **kwargs: Any,
    ) -> "Llava15ChatHandler":
        """Download a CLIP model file from the Hugging Face Hub and build a handler.

        Args:
            repo_id: Hub repository id.
            filename: fnmatch-style pattern that must match exactly one file in
                the repository.
            local_dir, local_dir_use_symlinks, cache_dir: Forwarded to
                `huggingface_hub.hf_hub_download`.
            **kwargs: Forwarded to the handler constructor.

        Raises:
            ImportError: If `huggingface-hub` is not installed.
            ValueError: If zero or several repository files match `filename`.
        """
        import fnmatch
        from pathlib import Path

        try:
            from huggingface_hub import hf_hub_download, HfFileSystem  # type: ignore
            from huggingface_hub.utils import validate_repo_id  # type: ignore
        except ImportError:
            raise ImportError(
                "Llava15ChatHandler.from_pretrained requires the huggingface-hub package. "
                "You can install it with `pip install huggingface-hub`."
            )

        validate_repo_id(repo_id)

        hffs = HfFileSystem()

        files = [
            file["name"] if isinstance(file, dict) else file
            for file in hffs.ls(repo_id)  # type: ignore
        ]

        # Paths relative to the repository root.
        file_list: List[str] = [str(Path(file).relative_to(repo_id)) for file in files]

        matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore

        if len(matching_files) == 0:
            raise ValueError(
                f"No file found in {repo_id} that match {filename}\n\n"
                f"Available Files:\n{json.dumps(file_list)}"
            )

        if len(matching_files) > 1:
            raise ValueError(
                f"Multiple files found in {repo_id} matching {filename}\n\n"
                f"Available Files:\n{json.dumps(file_list)}"
            )

        (matching_file,) = matching_files

        subfolder = str(Path(matching_file).parent)
        filename = Path(matching_file).name

        # hf_hub_download returns the local path of the file, including the
        # subfolder when `local_dir` is used (joining local_dir + filename
        # dropped it for files that are not at the repository root).
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=cast(Union[str, Path, None], local_dir),
            local_dir_use_symlinks=local_dir_use_symlinks,
            cache_dir=cast(Union[str, Path, None], cache_dir),
        )

        return cls(
            clip_model_path=model_path,
            **kwargs,
        )
class ObsidianChatHandler(Llava15ChatHandler):
    """LLaVA handler for Obsidian models: ChatML turns separated by "###".

    Only ``CHAT_FORMAT`` is overridden. ``DEFAULT_SYSTEM_MESSAGE`` is inherited
    from ``Llava15ChatHandler``, so that default system prompt is injected when
    the request has no system message.
    """

    # Prompt Format
    # The model follows the ChatML format, but uses ### as the separator:
    # <|im_start|>user
    # What is this sign about?\n<image>
    # ###
    # <|im_start|>assistant
    # The sign is about bullying, and it is placed on a black background with a red background.
    # ###
    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System message
        "{% if message.role == 'system' %}"
        "<|im_start|>system\n"
        "{{ message.content }}\n"
        "###\n"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "<|im_start|>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        # NOTE: Jinja's `iterable` test is also true for strings. For string
        # content the loops below iterate characters, which have no `type`,
        # so they emit nothing.
        "{% if message.content is iterable %}"
        # All image URLs of the turn come first ...
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        # ... followed by all text parts.
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "###\n"
        "{% endif %}"
        # Assistant message
        # NOTE(review): `None` content (e.g. tool-call turns) would render as
        # the text "None" here; the base template guards this case.
        "{% if message.role == 'assistant' %}"
        "<|im_start|>assistant\n"
        "{{ message.content }}"
        "###\n"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
class MoondreamChatHandler(Llava15ChatHandler):
    """LLaVA handler for Moondream: ``<image>\\n\\nQuestion: ...\\n\\nAnswer:`` prompts.

    System messages are not rendered by this template. The inherited
    ``DEFAULT_SYSTEM_MESSAGE`` is still prepended to ``messages`` by the base
    ``__call__``, but it has no effect on the prompt.
    """

    # Chat Format:
    # f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is iterable %}"
        # <image>: every image URL of the turn, each followed by a blank line
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}\n\n"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}\n\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question: one "Question:" line per text part
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "Question: {{ content.text }}\n\n"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question: plain-string content (the iterable branch above emits
        # nothing for strings because characters have no `type`)
        "{% if message.content is string %}"
        "Question: {{ message.content }}\n\n"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' %}"
        "Answer: {{ message.content }}\n\n"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "Answer: "
        "{% endif %}"
    )
class Llava16ChatHandler(Llava15ChatHandler):
    """LLaVA-1.6 handler: overrides the default system message and prompt template.

    NOTE(review): the template below emits no "USER:"/"ASSISTANT:" markers and
    ends with the generation prompt "Answer: " (the same one used by
    MoondreamChatHandler). That does not match the example prompt documented
    below, which uses "USER: ... ASSISTANT:". Confirm which format LLaVA-1.6
    checkpoints expect before relying on this handler.
    """

    # Prepended by the base `__call__` when the request has no system message.
    DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant.  The assistant gives helpful, detailed, and polite answers to the human's questions. "

    # Example prompt
    # "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System content is emitted verbatim, with no separator of its own.
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is iterable %}"
        # <image>: each image URL followed by a newline
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}\n"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question: text parts concatenated without separators
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question: plain-string content (the iterable branch above emits
        # nothing for strings because characters have no `type`)
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "Answer: "
        "{% endif %}"
    )
class NanoLlavaChatHandler(Llava15ChatHandler):
    """LLaVA handler for nanoLLaVA: ChatML turns without a newline after ``<|im_end|>``.

    Only ``CHAT_FORMAT`` is overridden. The inherited ``DEFAULT_SYSTEM_MESSAGE``
    is injected when the request has no system message.
    """

    # Prompt Format
    # The model follows the ChatML standard, but without \n after <|im_end|>:
    # <|im_start|>system
    # Answer the question<|im_end|><|im_start|>user
    # <image>
    # What is the picture about?<|im_end|><|im_start|>assistant
    # NOTE(review): the template writes image URLs directly followed by the
    # text (no "\n" between them as in the example above) — confirm whether
    # the model expects that newline.
    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System message
        "{% if message.role == 'system' %}"
        "<|im_start|>system\n"
        "{{ message.content }}"
        "<|im_end|>"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "<|im_start|>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        # NOTE: also true for strings in Jinja; iterating characters emits
        # nothing here because they have no `type`.
        "{% if message.content is iterable %}"
        # All image URLs of the turn first ...
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        # ... then all text parts.
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "<|im_end|>"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        "<|im_start|>assistant\n"
        "{{ message.content }}"
        "<|im_end|>"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
2024-02-12 15:56:07 -05:00
2024-05-02 11:32:18 -04:00
class Llama3VisionAlpha(Llava15ChatHandler):
    """LLaVA handler for Llama-3-Vision-Alpha using the Llama 3 turn format.

    Each supported role (system, user, assistant) is rendered as a
    ``<|start_header_id|>{role}<|end_header_id|>`` header, two newlines, the
    content, then ``<|eot_id|>``. Image URLs in a user turn precede its text.
    No default system message is injected (``DEFAULT_SYSTEM_MESSAGE = None``).
    """

    # question = "<image>" + q
    # prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    DEFAULT_SYSTEM_MESSAGE = None

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # Each role opens its own header. A single header opened for every
        # message left system turns as an empty
        # "<|start_header_id|><|eot_id|>" pair with their content dropped;
        # roles not handled here now render nothing.
        # System message
        "{% if message.role == 'system' %}"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        "{{ message.content }}"
        "<|eot_id|>"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "{% if message.content is iterable %}"
        # <image>
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question:
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question: plain-string content (characters have no `type`, so the
        # iterable branch above emits nothing for strings)
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        "<|eot_id|>"
        "{% endif %}"
        # Answer: skip `None` content (e.g. tool-call turns), which Jinja
        # would otherwise render as the literal text "None".
        "{% if message.role == 'assistant' %}"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "{% if message.content is not none %}"
        "{{ message.content }}"
        "{% endif %}"
        "<|eot_id|>"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "{% endif %}"
    )
2024-02-12 15:56:07 -05:00
@register_chat_completion_handler ( " chatml-function-calling " )
def chatml_function_calling (
llama : llama . Llama ,
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
functions : Optional [ List [ llama_types . ChatCompletionFunction ] ] = None ,
function_call : Optional [ llama_types . ChatCompletionRequestFunctionCall ] = None ,
tools : Optional [ List [ llama_types . ChatCompletionTool ] ] = None ,
tool_choice : Optional [ llama_types . ChatCompletionToolChoiceOption ] = None ,
temperature : float = 0.2 ,
top_p : float = 0.95 ,
top_k : int = 40 ,
min_p : float = 0.05 ,
typical_p : float = 1.0 ,
stream : bool = False ,
stop : Optional [ Union [ str , List [ str ] ] ] = [ ] ,
response_format : Optional [ llama_types . ChatCompletionRequestResponseFormat ] = None ,
max_tokens : Optional [ int ] = None ,
presence_penalty : float = 0.0 ,
frequency_penalty : float = 0.0 ,
repeat_penalty : float = 1.1 ,
tfs_z : float = 1.0 ,
mirostat_mode : int = 0 ,
mirostat_tau : float = 5.0 ,
mirostat_eta : float = 0.1 ,
model : Optional [ str ] = None ,
logits_processor : Optional [ llama . LogitsProcessorList ] = None ,
grammar : Optional [ llama . LlamaGrammar ] = None ,
2024-04-10 03:41:55 -04:00
logprobs : Optional [ bool ] = None ,
top_logprobs : Optional [ int ] = None ,
2024-02-12 15:56:07 -05:00
* * kwargs , # type: ignore
) - > Union [
llama_types . CreateChatCompletionResponse ,
Iterator [ llama_types . CreateChatCompletionStreamResponse ] ,
] :
2024-04-10 03:41:55 -04:00
print ( logprobs )
2024-02-12 15:56:07 -05:00
function_calling_template = (
" { % f or message in messages % } "
" <|im_start|> {{ message.role }} \n "
# System message
" { % i f message.role == ' system ' % } "
" {{ message.content }} "
" { % i f tool_calls % } "
" \n \n You have access to the following functions: \n "
" { % f or tool in tools % } "
" \n functions. {{ tool.function.name }}: \n "
" {{ tool.function.parameters | tojson }} "
" \n { % e ndfor % } "
" \n \n You can respond to users messages with either a single message or one or more function calls. "
" \n \n To respond with a message begin the message with ' message: ' , use the following format: "
" \n \n message: "
" \n <message> "
" \n \n To respond with one or more function calls begin the message with ' functions.<function_name>: ' , use the following format: "
" \n \n functions.<function_name>: "
' \n { " arg1 " : " value1 " , " arg2 " : " value2 " } '
" \n functions.<function_name>: "
' \n { " arg1 " : " value1 " , " arg2 " : " value2 " } '
" { % e ndif % } "
2024-02-13 23:02:50 -05:00
" <|im_end|> \n "
2024-02-12 15:56:07 -05:00
" { % e ndif % } "
# User message
" { % i f message.role == ' user ' % } "
" {{ message.content }} "
2024-02-13 23:02:50 -05:00
" <|im_end|> \n "
2024-02-12 15:56:07 -05:00
" { % e ndif % } "
# Assistant message
" { % i f message.role == ' assistant ' % } "
## Reglar message
" { % i f message.content and message.content | length > 0 % } "
2024-02-13 03:11:35 -05:00
" { % i f tool_calls % } "
2024-02-12 15:56:07 -05:00
" message: \n "
2024-02-13 03:11:35 -05:00
" { % e ndif % } "
2024-02-12 15:56:07 -05:00
" {{ message.content }} "
2024-02-13 23:02:50 -05:00
" <|im_end|> \n "
2024-02-12 15:56:07 -05:00
" { % e ndif % } "
## Function calls
2024-02-13 03:11:35 -05:00
" { % i f ' tool_calls ' in message % } "
2024-02-12 15:56:07 -05:00
" { % f or tool_call in message.tool_calls % } "
" functions. {{ tool_call.function.name }}: \n "
" {{ tool_call.function.arguments }} "
" { % e ndfor % } "
2024-02-13 23:02:50 -05:00
" <|im_end|> \n "
2024-02-12 15:56:07 -05:00
" { % e ndif % } "
" { % e ndif % } "
" { % e ndfor % } "
2024-02-13 23:02:50 -05:00
" { % i f add_generation_prompt % }<|im_start|>assistant \n { % e ndif % } "
2024-02-12 15:56:07 -05:00
)
template_renderer = jinja2 . Environment (
loader = jinja2 . BaseLoader ( ) ,
autoescape = jinja2 . select_autoescape ( [ " html " , " xml " ] ) ,
undefined = jinja2 . StrictUndefined ,
) . from_string ( function_calling_template )
# Convert legacy functions to tools
if functions is not None :
tools = [
{
" type " : " function " ,
" function " : function ,
}
for function in functions
]
# Convert legacy function_call to tool_choice
if function_call is not None :
if isinstance ( function_call , str ) and (
function_call == " none " or function_call == " auto "
) :
tool_choice = function_call
if isinstance ( function_call , dict ) and " name " in function_call :
tool_choice = {
" type " : " function " ,
" function " : {
" name " : function_call [ " name " ] ,
} ,
}
2024-02-13 23:02:50 -05:00
stop = [ stop , " <|im_end|> " ] if isinstance ( stop , str ) else stop + [ " <|im_end|> " ] if stop else [ " <|im_end|> " ]
2024-02-12 15:56:07 -05:00
# Case 1: No tool choice by user
if (
tool_choice is None
or ( isinstance ( tool_choice , str ) and tool_choice == " none " )
or tools is None
or len ( tools ) == 0
) :
prompt = template_renderer . render (
messages = messages ,
tools = [ ] ,
tool_calls = None ,
2024-02-13 03:24:41 -05:00
add_generation_prompt = True ,
2024-02-12 15:56:07 -05:00
)
2024-03-15 12:58:34 -04:00
2024-02-12 15:56:07 -05:00
if response_format is not None and response_format [ " type " ] == " json_object " :
2024-03-15 12:58:34 -04:00
grammar = _grammar_for_response_format ( response_format )
2024-02-12 15:56:07 -05:00
return _convert_completion_to_chat (
llama . create_completion (
prompt = prompt ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = stream ,
stop = stop ,
max_tokens = max_tokens ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = grammar ,
2024-04-10 03:41:55 -04:00
logprobs = top_logprobs if logprobs else None ,
2024-02-12 15:56:07 -05:00
) ,
stream = stream ,
)
# Case 2: Tool choice by user
if isinstance ( tool_choice , dict ) :
tool_name = tool_choice [ " function " ] [ " name " ]
tool = next (
( tool for tool in tools if tool [ " function " ] [ " name " ] == tool_name ) , None
)
if tool is None :
raise ValueError ( f " Tool with name ' { tool_name } ' not found in tools " )
prompt = template_renderer . render (
messages = messages ,
tools = tools ,
tool_calls = True ,
2024-02-13 03:24:41 -05:00
add_generation_prompt = True ,
2024-02-12 15:56:07 -05:00
)
prompt + = f " functions. { tool_name } : \n "
try :
grammar = llama_grammar . LlamaGrammar . from_json_schema (
json . dumps ( tool [ " function " ] [ " parameters " ] ) , verbose = llama . verbose
)
except Exception as e :
grammar = llama_grammar . LlamaGrammar . from_string (
llama_grammar . JSON_GBNF , verbose = llama . verbose
)
if llama . verbose :
print (
" Failed to parse function body as JSON schema, falling back to default grammar "
)
print ( e )
completion_or_chunks = llama . create_completion (
prompt = prompt ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = stream ,
stop = stop ,
max_tokens = max_tokens ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = grammar ,
)
return _convert_completion_to_chat_function (
tool_name , completion_or_chunks , stream
)
# Case 3: Automatic tool choice
assert isinstance ( tool_choice , str ) and tool_choice == " auto "
function_names = " | " . join (
[ f ''' " functions. { tool [ ' function ' ] [ ' name ' ] } : " ''' for tool in tools ]
)
initial_gbnf_tool_grammar = (
""" root ::= functions | " message: " \n """
f """ functions ::= { function_names } \n """
)
follow_up_gbnf_tool_grammar = (
""" root ::= functions | " <|im_end|> " \n """
f """ functions ::= { function_names } \n """
)
prompt = template_renderer . render (
messages = messages ,
tools = tools ,
tool_calls = True ,
2024-02-13 03:24:41 -05:00
add_generation_prompt = True ,
2024-02-12 15:56:07 -05:00
)
completion_or_chunks = llama . create_completion (
prompt = prompt ,
temperature = 0 ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = False ,
stop = [ " : " ] ,
max_tokens = None ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = llama_grammar . LlamaGrammar . from_string (
initial_gbnf_tool_grammar , verbose = llama . verbose
) ,
)
completion : llama_types . CreateCompletionResponse = completion_or_chunks # type: ignore
text = completion [ " choices " ] [ 0 ] [ " text " ]
if " message " in text :
return _convert_completion_to_chat (
llama . create_completion (
prompt = prompt + " message: \n " ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = stream ,
stop = [ " <|im_end|> " ] ,
2024-04-10 03:41:55 -04:00
logprobs = top_logprobs if logprobs else None ,
2024-02-12 15:56:07 -05:00
max_tokens = None ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = llama_grammar . LlamaGrammar . from_string (
follow_up_gbnf_tool_grammar , verbose = llama . verbose
) ,
) ,
stream = stream ,
)
# One or more function calls
tool_name = text [ len ( " functions. " ) : ]
tool = next ( ( tool for tool in tools if tool [ " function " ] [ " name " ] == tool_name ) , None )
if not stream :
2024-04-05 10:50:49 -04:00
completions : List [ llama_types . CreateCompletionResponse ] = [ ]
completions_tool_name : List [ str ] = [ ]
2024-02-12 15:56:07 -05:00
while tool is not None :
prompt + = f " functions. { tool_name } : \n "
try :
grammar = llama_grammar . LlamaGrammar . from_json_schema (
json . dumps ( tool [ " function " ] [ " parameters " ] ) , verbose = llama . verbose
)
except Exception as e :
grammar = llama_grammar . LlamaGrammar . from_string (
llama_grammar . JSON_GBNF , verbose = llama . verbose
)
if llama . verbose :
print (
" Failed to parse function body as JSON schema, falling back to default grammar "
)
print ( e )
completion_or_chunks = llama . create_completion (
prompt = prompt ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = False ,
stop = stop ,
max_tokens = None ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = grammar ,
)
2024-04-05 10:50:49 -04:00
completion_or_chunks = cast ( llama_types . CreateCompletionResponse , completion_or_chunks )
2024-02-12 15:56:07 -05:00
completions . append ( completion_or_chunks )
completions_tool_name . append ( tool_name )
prompt + = completion_or_chunks [ " choices " ] [ 0 ] [ " text " ]
prompt + = " \n "
response = llama . create_completion (
prompt = prompt ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = False ,
stop = stop ,
max_tokens = None ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = llama_grammar . LlamaGrammar . from_string (
follow_up_gbnf_tool_grammar , verbose = llama . verbose
) ,
)
2024-04-05 10:50:49 -04:00
response = cast ( llama_types . CreateCompletionResponse , response )
2024-02-12 15:56:07 -05:00
tool_name = response [ " choices " ] [ 0 ] [ " text " ] [ len ( " functions. " ) : ]
tool = next (
( tool for tool in tools if tool [ " function " ] [ " name " ] == tool_name ) , None
)
# Merge completions
2024-04-05 10:50:49 -04:00
function_call_dict : Union [ Dict [ str , str ] , Dict [ Literal [ " function_call " ] , llama_types . ChatCompletionRequestAssistantMessageFunctionCall ] ] = {
2024-02-12 15:56:07 -05:00
" function_call " : {
" name " : tool_name ,
" arguments " : completions [ 0 ] [ " choices " ] [ 0 ] [ " text " ] ,
}
} if len ( completions ) == 1 else { }
return {
" id " : " chat " + completion [ " id " ] ,
" object " : " chat.completion " ,
" created " : completion [ " created " ] ,
" model " : completion [ " model " ] ,
" choices " : [
{
" finish_reason " : " tool_calls " ,
" index " : 0 ,
2024-04-10 03:41:55 -04:00
" logprobs " : completion [ " choices " ] [ 0 ] [ " logprobs " ] ,
2024-02-12 15:56:07 -05:00
" message " : {
" role " : " assistant " ,
" content " : None ,
" tool_calls " : [
{
" id " : " call_ "
+ f " _ { i } _ "
+ tool_name
+ " _ "
+ completion [ " id " ] ,
" type " : " function " ,
" function " : {
" name " : tool_name ,
" arguments " : completion [ " choices " ] [ 0 ] [ " text " ] ,
} ,
}
for i , ( tool_name , completion ) in enumerate (
zip ( completions_tool_name , completions )
)
] ,
2024-04-05 10:50:49 -04:00
* * function_call_dict
2024-02-12 15:56:07 -05:00
} ,
}
] ,
" usage " : {
" completion_tokens " : sum (
2024-04-05 10:50:49 -04:00
completion [ " usage " ] [ " completion_tokens " ] if " usage " in completion else 0
2024-02-12 15:56:07 -05:00
for completion in completions
) ,
" prompt_tokens " : sum (
2024-04-05 10:50:49 -04:00
completion [ " usage " ] [ " prompt_tokens " ] if " usage " in completion else 0
for completion in completions
2024-02-12 15:56:07 -05:00
) ,
" total_tokens " : sum (
2024-04-05 10:50:49 -04:00
completion [ " usage " ] [ " total_tokens " ] if " usage " in completion else 0
for completion in completions
2024-02-12 15:56:07 -05:00
) ,
} ,
}
2024-05-08 07:21:27 +01:00
raise ValueError ( " Automatic streaming tool choice is not supported " )