fix: Fix and optimize functionary chat handler (#1282)
* fix functionary chat logic * further fixes --------- Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
parent
8d298b4750
commit
8a60c7bc8c
1 changed files with 65 additions and 66 deletions
|
@ -1596,13 +1596,15 @@ def functionary_v1_v2_chat_handler(
|
|||
function_call = (
|
||||
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
|
||||
)
|
||||
else:
|
||||
function_call = "auto"
|
||||
|
||||
prompt = prepare_messages_for_inference(
|
||||
messages, tokenizer, version, functions, tools
|
||||
)
|
||||
|
||||
# If no tools/functions are provided
|
||||
if function_call is None and (functions is None or len(functions) == 0):
|
||||
if function_call == "none" or functions is None or len(functions) == 0:
|
||||
if version == "v1":
|
||||
stop = END_ASSISTANT_TOKEN
|
||||
else:
|
||||
|
@ -1630,6 +1632,7 @@ def functionary_v1_v2_chat_handler(
|
|||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
)
|
||||
completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
|
||||
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
|
||||
|
||||
assert stream is False # TODO: support stream mode
|
||||
|
@ -1692,13 +1695,12 @@ def functionary_v1_v2_chat_handler(
|
|||
|
||||
return completion
|
||||
|
||||
content = ""
|
||||
function_calls, function_bodies = [], []
|
||||
|
||||
if version == "v1":
|
||||
# If no or "auto" tool_choice/function_call
|
||||
if function_call is None or (
|
||||
isinstance(function_call, str) and function_call == "auto"
|
||||
):
|
||||
if isinstance(function_call, str) and function_call == "auto":
|
||||
stops = ["\n", END_ASSISTANT_TOKEN]
|
||||
# If tool_choice/function_call is "none"
|
||||
elif isinstance(function_call, str) and function_call == "none":
|
||||
|
@ -1747,70 +1749,67 @@ def functionary_v1_v2_chat_handler(
|
|||
else:
|
||||
function_bodies.append(completion_text.strip())
|
||||
else:
|
||||
# Loop until all parallel function calls are generated
|
||||
while True:
|
||||
# If no or "auto" tool_choice/function_call
|
||||
if function_call is None or (
|
||||
isinstance(function_call, str) and function_call == "auto"
|
||||
):
|
||||
grammar = None
|
||||
stops = CONTENT_TOKEN
|
||||
# If tool_choice/function_call is "none"
|
||||
elif isinstance(function_call, str) and function_call == "none":
|
||||
prompt = (
|
||||
prepare_messages_for_inference(messages, tokenizer, version, [], [])
|
||||
+ "all\n<|content|>"
|
||||
)
|
||||
stops = STOP_TOKEN
|
||||
# If tool_choice/function_call is provided
|
||||
elif isinstance(function_call, dict):
|
||||
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
|
||||
stops = STOP_TOKEN
|
||||
function_call = function_call["name"]
|
||||
function_calls.append(function_call)
|
||||
grammar = get_grammar(function_call)
|
||||
else:
|
||||
prompt = prompt
|
||||
stops = STOP_TOKEN
|
||||
|
||||
# If tool_choice/function_call is "none"
|
||||
if isinstance(function_call, str) and function_call == "none":
|
||||
prompt = (
|
||||
prepare_messages_for_inference(messages, tokenizer, version, [], [])
|
||||
+ "all\n<|content|>"
|
||||
)
|
||||
stops = [STOP_TOKEN, FROM_TOKEN]
|
||||
completion = create_completion(stop=stops)
|
||||
completion["choices"][0]["text"] = completion["choices"][0]["text"].strip()
|
||||
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
||||
# If tool_choice/function_call is provided
|
||||
elif isinstance(function_call, dict):
|
||||
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
|
||||
function_call = function_call["name"]
|
||||
function_calls.append(function_call)
|
||||
grammar = get_grammar(function_call)
|
||||
stops = [STOP_TOKEN, FROM_TOKEN]
|
||||
completion = create_completion(stop=stops)
|
||||
completion_text = completion["choices"][0]["text"]
|
||||
|
||||
# If the generation does not involve a function call
|
||||
if prompt.endswith("all\n<|content|>") and not completion_text.startswith(
|
||||
"all"
|
||||
):
|
||||
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
||||
# Generate model response if the model decides not to call any function
|
||||
elif prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all"):
|
||||
prompt += completion_text + CONTENT_TOKEN
|
||||
completion = create_completion(stop=STOP_TOKEN)
|
||||
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
||||
# Generate parameters if model decides to call a function
|
||||
elif prompt.endswith(RECIPIENT_TOKEN):
|
||||
function_calls.append(completion_text[:-1])
|
||||
grammar = get_grammar(function_calls[-1])
|
||||
completion = create_completion(stop=[STOP_TOKEN, "\n"])
|
||||
function_bodies.append(completion["choices"][0]["text"].strip())
|
||||
prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
|
||||
function_bodies.append(completion_text.strip())
|
||||
# If "auto" or no tool_choice/function_call
|
||||
elif isinstance(function_call, str) and function_call == "auto":
|
||||
while True:
|
||||
# Generate function name first
|
||||
grammar = None
|
||||
|
||||
# Try to generate the beginning of next turn
|
||||
# If empty completion, break from loop
|
||||
next_turn_completion_text = create_completion(
|
||||
stop=[STOP_TOKEN, RECIPIENT_TOKEN]
|
||||
)["choices"][0]["text"]
|
||||
if len(next_turn_completion_text) > 0:
|
||||
prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
|
||||
stops = CONTENT_TOKEN
|
||||
completion = create_completion(stop=stops)
|
||||
completion_text = completion["choices"][0]["text"]
|
||||
function_name = completion_text.strip()
|
||||
if function_name == "all":
|
||||
prompt += "all\n<|content|>"
|
||||
else:
|
||||
break
|
||||
# Break from loop if tool_choice/function_call is provided as a dict
|
||||
else:
|
||||
function_bodies.append(completion_text.strip())
|
||||
break
|
||||
function_call = completion_text.strip()
|
||||
prompt += f"{function_call}\n<|content|>"
|
||||
function_calls.append(function_call)
|
||||
grammar = get_grammar(function_call)
|
||||
# Generate content
|
||||
stops = [RECIPIENT_TOKEN, STOP_TOKEN]
|
||||
completion = create_completion(stop=stops)
|
||||
completion_text = completion["choices"][0]["text"]
|
||||
if function_name == "all":
|
||||
content += completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n")
|
||||
content = content.lstrip()
|
||||
# Check whether the model wants to generate another turn
|
||||
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
|
||||
cleaned_completion_text = completion_text.removesuffix("\n<|from|>assistant\n").removesuffix("\n<|from|> assistant\n").strip()
|
||||
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
|
||||
else:
|
||||
break
|
||||
else:
|
||||
function_bodies.append(completion_text.strip())
|
||||
# Check whether the model wants to generate another turn
|
||||
prompt += completion_text.strip()
|
||||
grammar = None
|
||||
completion = create_completion(stop=stops)
|
||||
if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
|
||||
prompt += "\n<|from|>assistant\n<|recipient|>"
|
||||
else:
|
||||
break
|
||||
|
||||
assert "usage" in completion
|
||||
assert len(function_calls) > 0
|
||||
assert len(function_calls) == len(function_bodies)
|
||||
|
||||
tool_calls = []
|
||||
|
@ -1843,14 +1842,14 @@ def functionary_v1_v2_chat_handler(
|
|||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"content": None if content == "" else content,
|
||||
"function_call": {
|
||||
"name": tool_calls[0]["function"]["name"],
|
||||
"arguments": tool_calls[0]["function"]["arguments"],
|
||||
},
|
||||
"tool_calls": tool_calls,
|
||||
} if len(tool_calls) > 0 else None,
|
||||
"tool_calls": tool_calls if len(tool_calls) > 0 else None,
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
|
||||
}
|
||||
],
|
||||
usage=completion["usage"],
|
||||
|
|
Loading…
Add table
Reference in a new issue