fix: Fix and optimize functionary chat handler (#1282)

* fix functionary chat logic

* further fixes

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Jeffrey Fong 2024-03-18 22:40:57 +08:00 committed by GitHub
parent 8d298b4750
commit 8a60c7bc8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1596,13 +1596,15 @@ def functionary_v1_v2_chat_handler(
function_call = (
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
)
else:
function_call = "auto"
prompt = prepare_messages_for_inference(
messages, tokenizer, version, functions, tools
)
# If no tools/functions are provided
if function_call is None and (functions is None or len(functions) == 0):
if function_call == "none" or functions is None or len(functions) == 0:
if version == "v1":
stop = END_ASSISTANT_TOKEN
else:
@ -1630,6 +1632,7 @@ def functionary_v1_v2_chat_handler(
logits_processor=logits_processor,
grammar=grammar,
)
completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
assert stream is False # TODO: support stream mode
@ -1692,13 +1695,12 @@ def functionary_v1_v2_chat_handler(
return completion
content = ""
function_calls, function_bodies = [], []
if version == "v1":
# If no or "auto" tool_choice/function_call
if function_call is None or (
isinstance(function_call, str) and function_call == "auto"
):
if isinstance(function_call, str) and function_call == "auto":
stops = ["\n", END_ASSISTANT_TOKEN]
# If tool_choice/function_call is "none"
elif isinstance(function_call, str) and function_call == "none":
@ -1747,70 +1749,67 @@ def functionary_v1_v2_chat_handler(
else:
function_bodies.append(completion_text.strip())
else:
# Loop until all parallel function calls are generated
while True:
# If no or "auto" tool_choice/function_call
if function_call is None or (
isinstance(function_call, str) and function_call == "auto"
):
grammar = None
stops = CONTENT_TOKEN
# If tool_choice/function_call is "none"
elif isinstance(function_call, str) and function_call == "none":
if isinstance(function_call, str) and function_call == "none":
prompt = (
prepare_messages_for_inference(messages, tokenizer, version, [], [])
+ "all\n<|content|>"
)
stops = STOP_TOKEN
stops = [STOP_TOKEN, FROM_TOKEN]
completion = create_completion(stop=stops)
completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# If tool_choice/function_call is provided
elif isinstance(function_call, dict):
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
stops = STOP_TOKEN
function_call = function_call["name"]
function_calls.append(function_call)
grammar = get_grammar(function_call)
else:
prompt = prompt
stops = STOP_TOKEN
stops = [STOP_TOKEN, FROM_TOKEN]
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
# If the generation does not involve a function call
if prompt.endswith("all\n<|content|>") and not completion_text.startswith(
"all"
):
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# Generate model response if the model decides not to call any function
elif prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all"):
prompt += completion_text + CONTENT_TOKEN
completion = create_completion(stop=STOP_TOKEN)
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# Generate parameters if model decides to call a function
elif prompt.endswith(RECIPIENT_TOKEN):
function_calls.append(completion_text[:-1])
grammar = get_grammar(function_calls[-1])
completion = create_completion(stop=[STOP_TOKEN, "\n"])
function_bodies.append(completion["choices"][0]["text"].strip())
prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
function_bodies.append(completion_text.strip())
# If "auto" or no tool_choice/function_call
elif isinstance(function_call, str) and function_call == "auto":
while True:
# Generate function name first
grammar = None
# Try to generate the beginning of next turn
# If empty completion, break from loop
next_turn_completion_text = create_completion(
stop=[STOP_TOKEN, RECIPIENT_TOKEN]
)["choices"][0]["text"]
if len(next_turn_completion_text) > 0:
prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
stops = CONTENT_TOKEN
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
function_name = completion_text.strip()
if function_name == "all":
prompt += "all\n<|content|>"
else:
function_call = completion_text.strip()
prompt += f"{function_call}\n<|content|>"
function_calls.append(function_call)
grammar = get_grammar(function_call)
# Generate content
stops = [RECIPIENT_TOKEN, STOP_TOKEN]
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
if function_name == "all":
content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
content = content.lstrip()
# Check whether the model wants to generate another turn
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
else:
break
# Break from loop if tool_choice/function_call is provided as a dict
else:
function_bodies.append(completion_text.strip())
# Check whether the model wants to generate another turn
prompt += completion_text.strip()
grammar = None
completion = create_completion(stop=stops)
if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
prompt += "\n<|from|>assistant\n<|recipient|>"
else:
break
assert "usage" in completion
assert len(function_calls) > 0
assert len(function_calls) == len(function_bodies)
tool_calls = []
@ -1843,14 +1842,14 @@ def functionary_v1_v2_chat_handler(
"index": 0,
"message": {
"role": "assistant",
"content": None,
"content": None if content == "" else content,
"function_call": {
"name": tool_calls[0]["function"]["name"],
"arguments": tool_calls[0]["function"]["arguments"],
} if len(tool_calls) > 0 else None,
"tool_calls": tool_calls if len(tool_calls) > 0 else None,
},
"tool_calls": tool_calls,
},
"finish_reason": "tool_calls",
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
}
],
usage=completion["usage"],