fix: Fix and optimize functionary chat handler (#1282)

* fix functionary chat logic

* further fixes

---------

Co-authored-by: Andrei <abetlen@gmail.com>
Jeffrey Fong 2024-03-18 22:40:57 +08:00 committed by GitHub
parent 8d298b4750
commit 8a60c7bc8c

@@ -1596,13 +1596,15 @@ def functionary_v1_v2_chat_handler(
         function_call = (
             tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
         )
+    else:
+        function_call = "auto"
 
     prompt = prepare_messages_for_inference(
         messages, tokenizer, version, functions, tools
     )
 
     # If no tools/functions are provided
-    if function_call is None and (functions is None or len(functions) == 0):
+    if function_call == "none" or functions is None or len(functions) == 0:
         if version == "v1":
             stop = END_ASSISTANT_TOKEN
         else:
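
The first hunk normalizes the request up front: a missing tool_choice now maps to "auto" instead of leaving function_call as None, and the no-tools early path also fires on an explicit "none". A minimal standalone sketch of that normalization, assuming OpenAI-style tool_choice values (the helper name and the asserts are illustrative, not part of the handler):

    from typing import Any, Dict, Optional, Union

    def normalize_function_call(
        function_call: Optional[Union[str, Dict[str, Any]]],
        tool_choice: Optional[Union[str, Dict[str, Any]]],
    ) -> Union[str, Dict[str, Any]]:
        # Mirrors the hunk: tool_choice wins when present; the fallback is
        # now the string "auto" rather than None.
        if tool_choice is not None:
            return tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
        return function_call if function_call is not None else "auto"

    assert normalize_function_call(None, None) == "auto"
    assert normalize_function_call(None, "none") == "none"
    assert normalize_function_call(None, {"type": "function", "function": {"name": "f"}}) == {"name": "f"}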
@@ -1630,6 +1632,7 @@ def functionary_v1_v2_chat_handler(
             logits_processor=logits_processor,
             grammar=grammar,
         )
+        completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
         return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream)  # type: ignore
 
     assert stream is False  # TODO: support stream mode
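
The lstrip() added here trims the leading whitespace the model tends to emit right after the content marker, so it does not leak into the chat message. With fabricated values:

    completion_or_completion_chunks = {"choices": [{"text": " Hello! How can I help?"}]}
    completion_or_completion_chunks["choices"][0]["text"] = (
        completion_or_completion_chunks["choices"][0]["text"].lstrip()
    )
    assert completion_or_completion_chunks["choices"][0]["text"] == "Hello! How can I help?"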
@@ -1692,13 +1695,12 @@ def functionary_v1_v2_chat_handler(
 
         return completion
 
+    content = ""
     function_calls, function_bodies = [], []
 
     if version == "v1":
         # If no or "auto" tool_choice/function_call
-        if function_call is None or (
-            isinstance(function_call, str) and function_call == "auto"
-        ):
+        if isinstance(function_call, str) and function_call == "auto":
             stops = ["\n", END_ASSISTANT_TOKEN]
         # If tool_choice/function_call is "none"
         elif isinstance(function_call, str) and function_call == "none":
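
Two things happen here: a `content` accumulator is introduced so the v2 path can collect plain-text output across turns, and the v1 condition drops the `function_call is None` arm, which is dead code once the earlier hunk guarantees function_call is "auto", "none", or a dict. A quick equivalence check over the values that can still reach this branch (illustrative):

    for function_call in ["auto", "none", {"name": "f"}]:  # the only reachable values
        old = function_call is None or (
            isinstance(function_call, str) and function_call == "auto"
        )
        new = isinstance(function_call, str) and function_call == "auto"
        assert old == new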
@@ -1747,70 +1749,67 @@ def functionary_v1_v2_chat_handler(
         else:
             function_bodies.append(completion_text.strip())
     else:
-        # Loop until all parallel function calls are generated
-        while True:
-            # If no or "auto" tool_choice/function_call
-            if function_call is None or (
-                isinstance(function_call, str) and function_call == "auto"
-            ):
-                grammar = None
-                stops = CONTENT_TOKEN
-            # If tool_choice/function_call is "none"
-            elif isinstance(function_call, str) and function_call == "none":
-                prompt = (
-                    prepare_messages_for_inference(messages, tokenizer, version, [], [])
-                    + "all\n<|content|>"
-                )
-                stops = STOP_TOKEN
-            # If tool_choice/function_call is provided
-            elif isinstance(function_call, dict):
-                prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
-                stops = STOP_TOKEN
-                function_call = function_call["name"]
-                function_calls.append(function_call)
-                grammar = get_grammar(function_call)
-            else:
-                prompt = prompt
-                stops = STOP_TOKEN
-
+        # If tool_choice/function_call is "none"
+        if isinstance(function_call, str) and function_call == "none":
+            prompt = (
+                prepare_messages_for_inference(messages, tokenizer, version, [], [])
+                + "all\n<|content|>"
+            )
+            stops = [STOP_TOKEN, FROM_TOKEN]
+            completion = create_completion(stop=stops)
+            completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
+            return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
+        # If tool_choice/function_call is provided
+        elif isinstance(function_call, dict):
+            prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
+            function_call = function_call["name"]
+            function_calls.append(function_call)
+            grammar = get_grammar(function_call)
+            stops = [STOP_TOKEN, FROM_TOKEN]
             completion = create_completion(stop=stops)
             completion_text = completion["choices"][0]["text"]
-
-            # If the generation does not involve a function call
-            if prompt.endswith("all\n<|content|>") and not completion_text.startswith(
-                "all"
-            ):
-                return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
-            # Generate model response if the model decides not to call any function
-            elif prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all"):
-                prompt += completion_text + CONTENT_TOKEN
-                completion = create_completion(stop=STOP_TOKEN)
-                return _convert_completion_to_chat(completion, stream=stream)  # type: ignore
-            # Generate parameters if model decides to call a function
-            elif prompt.endswith(RECIPIENT_TOKEN):
-                function_calls.append(completion_text[:-1])
-                grammar = get_grammar(function_calls[-1])
-                completion = create_completion(stop=[STOP_TOKEN, "\n"])
-                function_bodies.append(completion["choices"][0]["text"].strip())
-                prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
-                grammar = None
-
-                # Try to generate the beginning of next turn
-                # If empty completion, break from loop
-                next_turn_completion_text = create_completion(
-                    stop=[STOP_TOKEN, RECIPIENT_TOKEN]
-                )["choices"][0]["text"]
-                if len(next_turn_completion_text) > 0:
-                    prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
-                else:
-                    break
-            # Break from loop if tool_choice/function_call is provided as a dict
-            else:
-                function_bodies.append(completion_text.strip())
-                break
+            function_bodies.append(completion_text.strip())
+        # If "auto" or no tool_choice/function_call
+        elif isinstance(function_call, str) and function_call == "auto":
+            while True:
+                # Generate function name first
+                grammar = None
+                stops = CONTENT_TOKEN
+                completion = create_completion(stop=stops)
+                completion_text = completion["choices"][0]["text"]
+                function_name = completion_text.strip()
+                if function_name == "all":
+                    prompt += "all\n<|content|>"
+                else:
+                    function_call = completion_text.strip()
+                    prompt += f"{function_call}\n<|content|>"
+                    function_calls.append(function_call)
+                    grammar = get_grammar(function_call)
+                # Generate content
+                stops = [RECIPIENT_TOKEN, STOP_TOKEN]
+                completion = create_completion(stop=stops)
+                completion_text = completion["choices"][0]["text"]
+                if function_name == "all":
+                    content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
+                    content = content.lstrip()
+                    # Check whether the model wants to generate another turn
+                    if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
+                        cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
+                        prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
+                    else:
+                        break
+                else:
+                    function_bodies.append(completion_text.strip())
+                    # Check whether the model wants to generate another turn
+                    prompt += completion_text.strip()
+                    grammar = None
+                    completion = create_completion(stop=stops)
+                    if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
+                        prompt += "\n<|from|>assistant\n<|recipient|>"
+                    else:
+                        break
 
     assert "usage" in completion
-    assert len(function_calls) > 0
     assert len(function_calls) == len(function_bodies)
 
     tool_calls = []
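
The rewritten v2 branch replaces the single catch-all loop with three explicit cases: "none" returns plain content immediately, a dict forces exactly one named call, and "auto" loops. In "auto" mode each turn is generated in two steps: first the recipient (a function name, or "all" for plain content) up to the content marker, then the body, where a trailing "<|from|> assistant" marker means the model wants another turn. Below is a self-contained sketch of that loop with a stubbed create_completion replaying a fabricated transcript; it folds the diff's separate continuation probe into one completion for brevity, so it approximates the handler rather than reproducing it (needs Python 3.9+ for removesuffix, as the diff itself does):

    CONTENT_TOKEN, RECIPIENT_TOKEN, STOP_TOKEN = "<|content|>", "<|recipient|>", "<|stop|>"

    canned = iter([
        "get_weather",                               # recipient: a function name
        '{"city": "Tokyo"}\n<|from|> assistant\n',   # arguments, then another turn
        "all",                                       # recipient: plain content
        "It is sunny in Tokyo.",                     # content, no further turn marker
    ])

    def create_completion(stop):
        # Stub: ignores `stop` and replays the canned transcript above.
        return {"choices": [{"text": next(canned)}]}

    content, function_calls, function_bodies = "", [], []
    while True:
        # Step 1: generate the recipient (function name or "all").
        function_name = create_completion(stop=CONTENT_TOKEN)["choices"][0]["text"].strip()
        if function_name != "all":
            function_calls.append(function_name)
        # Step 2: generate the body for that recipient.
        completion_text = create_completion(stop=[RECIPIENT_TOKEN, STOP_TOKEN])["choices"][0]["text"]
        stripped = completion_text.removesuffix("\n<|from|> assistant\n")
        if function_name == "all":
            content += stripped
        else:
            function_bodies.append(stripped.strip())
        # A "<|from|> assistant" marker means the model wants another turn.
        if "<|from|> assistant" not in completion_text:
            break

    assert function_calls == ["get_weather"]
    assert function_bodies == ['{"city": "Tokyo"}']
    assert content.strip() == "It is sunny in Tokyo."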
@@ -1843,14 +1842,14 @@ def functionary_v1_v2_chat_handler(
                 "index": 0,
                 "message": {
                     "role": "assistant",
-                    "content": None,
+                    "content": None if content == "" else content,
                     "function_call": {
                         "name": tool_calls[0]["function"]["name"],
                         "arguments": tool_calls[0]["function"]["arguments"],
-                    },
-                    "tool_calls": tool_calls,
+                    } if len(tool_calls) > 0 else None,
+                    "tool_calls": tool_calls if len(tool_calls) > 0 else None,
                 },
-                "finish_reason": "tool_calls",
+                "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
             }
         ],
         usage=completion["usage"],
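
The last hunk makes the response shape follow what was actually generated: when the loop produced no tool calls, the message carries the accumulated content with finish_reason "stop"; function_call and tool_calls are populated only when calls exist. Previously the handler asserted at least one call and always reported "tool_calls". A sketch of that shaping with fabricated inputs (the helper is illustrative; the handler builds the dict inline):

    def shape_choice(content, tool_calls):
        return {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": None if content == "" else content,
                # The conditional guards tool_calls[0], so an empty list is safe.
                "function_call": {
                    "name": tool_calls[0]["function"]["name"],
                    "arguments": tool_calls[0]["function"]["arguments"],
                } if len(tool_calls) > 0 else None,
                "tool_calls": tool_calls if len(tool_calls) > 0 else None,
            },
            "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
        }

    plain = shape_choice("Hi there!", [])
    assert plain["finish_reason"] == "stop" and plain["message"]["tool_calls"] is None

    called = shape_choice("", [{"function": {"name": "f", "arguments": "{}"}}])
    assert called["finish_reason"] == "tool_calls"
    assert called["message"]["function_call"]["name"] == "f"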