feat: Implement streaming for Functionary v2 + Bug fixes (#1419)

* set up streaming for v2

* assert that streaming is only used with v2, fix tool_call vs function_call

* fix streaming with tool_choice/function_call

* make the legacy `functions` API return only 1 function call when function_call is 'auto'

* fix

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
Jeffrey Fong 2024-05-04 22:11:20 +08:00 committed by GitHub
parent f9b7221c8f
commit 1f56c648c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1894,6 +1894,8 @@ def functionary_v1_v2_chat_handler(
function_call = ( function_call = (
tool_choice if isinstance(tool_choice, str) else tool_choice["function"] tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
) )
elif function_call is not None:
pass
else: else:
function_call = "auto" function_call = "auto"
@ -1930,11 +1932,10 @@ def functionary_v1_v2_chat_handler(
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
) )
completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip() if stream is False:
completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
assert stream is False # TODO: support stream mode
def get_grammar(function_call): def get_grammar(function_call):
function_body = None function_body = None
for function in functions or []: for function in functions or []:
@ -1968,7 +1969,7 @@ def functionary_v1_v2_chat_handler(
return grammar return grammar
def create_completion(stop): def create_completion(prompt, stop, grammar):
completion = cast(llama_types.Completion, llama.create_completion( completion = cast(llama_types.Completion, llama.create_completion(
prompt=prompt, prompt=prompt,
temperature=temperature, temperature=temperature,
@ -1976,7 +1977,7 @@ def functionary_v1_v2_chat_handler(
top_k=top_k, top_k=top_k,
min_p=min_p, min_p=min_p,
typical_p=typical_p, typical_p=typical_p,
stream=False, stream=stream,
stop=stop, stop=stop,
max_tokens=max_tokens, max_tokens=max_tokens,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
@ -1997,171 +1998,480 @@ def functionary_v1_v2_chat_handler(
function_calls, function_bodies = [], [] function_calls, function_bodies = [], []
completion_tokens = 0 completion_tokens = 0
if version == "v1": def generate_streaming(tools, functions, function_call, prompt):
# If no or "auto" tool_choice/function_call assert version == "v2", "Streaming for v1 is not supported"
if isinstance(function_call, str) and function_call == "auto":
stops = ["\n", END_ASSISTANT_TOKEN]
# If tool_choice/function_call is provided
elif isinstance(function_call, dict):
prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
stops = END_FUNCTION_CALL_TOKEN
function_call = function_call["name"]
function_calls.append(function_call)
grammar = get_grammar(function_call)
else:
prompt = prompt
stops = ["\n", END_ASSISTANT_TOKEN]
completion = create_completion(stop=stops) chunk_id, chunk_created = None, None
completion_text = completion["choices"][0]["text"]
completion_tokens += completion["usage"]["completion_tokens"]
# If the generation does not involve a function call
if (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN not in completion_text
):
completion["usage"]["completion_tokens"] = completion_tokens
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# If the generation involves a function call in completion, generate the parameters
elif (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN in completion_text
):
prompt += (
completion_text.replace(
f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN
)
+ "\n"
)
function_calls.append(
completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip()
)
grammar = get_grammar(function_calls[-1])
completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
completion_tokens += completion["usage"]["completion_tokens"]
function_bodies.append(completion["choices"][0]["text"].strip())
# If the prompt involves a function call, just append generated parameters to function_bodies
else:
function_bodies.append(completion_text.strip())
else:
# If tool_choice/function_call is provided # If tool_choice/function_call is provided
if isinstance(function_call, dict): if isinstance(function_call, dict):
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}" prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
function_call = function_call["name"] grammar = get_grammar(function_call["name"])
function_calls.append(function_call)
grammar = get_grammar(function_call)
stops = [STOP_TOKEN, FROM_TOKEN] stops = [STOP_TOKEN, FROM_TOKEN]
completion = create_completion(stop=stops) tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
completion_text = completion["choices"][0]["text"] completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
completion_tokens += completion["usage"]["completion_tokens"] completion_text = ""
function_bodies.append(completion_text.strip()) first = True
for chunk in completion:
# Yield the tool/function name first
if first:
if tools is not None:
func_call_dict = {
"tool_calls": [
{
"index": 0,
"id": "call_" + tool_id,
"type": "function",
"function": {"name": function_call["name"], "arguments": ""},
}
]
}
else:
func_call_dict = {"function_call": {"name": function_call["name"], "arguments": ""}}
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk["id"],
object="chat.completion.chunk",
created=chunk["created"],
model=chunk["model"],
choices=[
{"index": 0, "logprobs": None, "delta": {"role": None, "content": None, **func_call_dict}}
],
)
first = False
if tools is not None:
func_call_dict = {
"tool_calls": [
{
"index": 0,
"id": "call_" + tool_id,
"type": "function",
"function": {
"name": None,
"arguments": chunk["choices"][0]["text"].rstrip(),
},
}
]
}
else:
func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}}
if len(chunk["choices"][0]["text"].rstrip()) > 0:
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk["id"],
object="chat.completion.chunk",
created=chunk["created"],
model=chunk["model"],
choices=[
{
"index": 0,
"logprobs": chunk["choices"][0]["logprobs"],
"delta": {
"role": None,
"content": None,
**func_call_dict,
},
}
],
)
# Yield tool_call/function_call stop message
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk["id"],
object="chat.completion.chunk",
created=chunk["created"],
model=chunk["model"],
choices=[
{
"index": 0,
"finish_reason": "tool_calls" if tools is not None else "function_call",
"logprobs": None,
"delta": {
"role": None, "content": None, "function_call": None, "tool_calls": None
},
}
],
)
# If "auto" or no tool_choice/function_call # If "auto" or no tool_choice/function_call
elif isinstance(function_call, str) and function_call == "auto": elif isinstance(function_call, str) and function_call == "auto":
tool_index = 0
while True: while True:
# Generate function name first # Generate function name first
grammar = None grammar = None
stops = CONTENT_TOKEN stops = CONTENT_TOKEN
completion = create_completion(stop=stops) completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
completion_text = completion["choices"][0]["text"] completion_text = ""
completion_tokens += completion["usage"]["completion_tokens"] for chunk in completion:
completion_text += chunk["choices"][0]["text"]
if chunk_id is None:
chunk_id = chunk["id"]
if chunk_created is None:
chunk_created = chunk["created"]
function_name = completion_text.strip() function_name = completion_text.strip()
if function_name == "all": if function_name == "all":
prompt += "all\n<|content|>" prompt += "all\n<|content|>"
# Yield the first empty message for content
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk_id,
model=chunk["model"],
created=chunk_created,
object="chat.completion.chunk",
choices=[
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"logprobs": None,
"finish_reason": None,
}
],
)
else: else:
function_call = completion_text.strip() prompt += f"{function_name}\n<|content|>"
prompt += f"{function_call}\n<|content|>" grammar = get_grammar(function_name)
function_calls.append(function_call) tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
grammar = get_grammar(function_call) if tools is not None:
func_call_dict = {
"tool_calls": [
{
"index": tool_index,
"id": "call_" + tool_id,
"type": "function",
"function": {"name": function_name, "arguments": ""},
}
]
}
else:
func_call_dict = {"function_call": {"name": function_name, "arguments": ""}}
# Stream function name
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk_id,
object="chat.completion.chunk",
created=chunk_created,
model=chunk["model"],
choices=[
{
"index": 0,
"logprobs": chunk["choices"][0]["logprobs"],
"delta": {
"role": "assistant",
"content": None,
**func_call_dict,
},
}
],
)
# Generate content # Generate content
stops = [RECIPIENT_TOKEN, STOP_TOKEN] stops = [RECIPIENT_TOKEN, STOP_TOKEN]
completion = create_completion(stop=stops) completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
completion_text = completion["choices"][0]["text"]
completion_tokens += completion["usage"]["completion_tokens"]
if function_name == "all": if function_name == "all":
if completion_text.endswith("\n<|from|>assistant\n"): completion_text = ""
content += completion_text[:-len("\n<|from|>assistant\n")] stop_sequence, buffer, is_end = "\n<|from|>assistant\n<|recipient|>", [], False
if completion_text.endswith("\n<|from|> assistant\n"): for i, chunk in enumerate(completion):
content += completion_text[-len("\n<|from|> assistant\n")] completion_text += chunk["choices"][0]["text"]
else: if is_end:
content += completion_text buffer.append(chunk["choices"][0]["text"].strip(" "))
content = content.lstrip() if stop_sequence.startswith("".join(buffer)):
continue
else:
buffer.pop()
while len(buffer) > 0:
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk_id,
object="chat.completion.chunk",
created=chunk_created,
model=chunk["model"],
choices=[
{
"index": 0,
"logprobs": chunk["choices"][0]["logprobs"],
"delta": {
"role": "assistant", "content": buffer.pop(0)
},
}
],
)
is_end = False
elif chunk["choices"][0]["text"] == "\n":
is_end = True
buffer.append(chunk["choices"][0]["text"].strip(" "))
continue
if len(buffer) == 0 and len(chunk["choices"][0]["text"]) > 0:
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk_id,
object="chat.completion.chunk",
created=chunk_created,
model=chunk["model"],
choices=[
{
"index": 0,
"logprobs": chunk["choices"][0]["logprobs"],
"delta": {
"role": "assistant",
"content": chunk["choices"][0]["text"] if i > 0 else chunk["choices"][0]["text"].lstrip()
},
}
],
)
# Check whether the model wants to generate another turn # Check whether the model wants to generate another turn
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text: if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
if completion_text.endswith("\n<|from|>assistant\n"): if completion_text.endswith("\n<|from|>assistant\n"):
cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip() cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
elif completion_text.endswith("\n<|from|> assistant\n"): elif completion_text.endswith("\n<|from|> assistant\n"):
cleaned_completion_text = completion_text[-len("\n<|from|> assistant\n")].strip() cleaned_completion_text = completion_text[:-len("\n<|from|> assistant\n")].strip()
else: else:
cleaned_completion_text = completion_text.strip() cleaned_completion_text = completion_text.strip()
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>" prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
else: else:
# Yield stop message
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk_id,
model=chunk["model"],
created=chunk_created,
object="chat.completion.chunk",
choices=[
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
)
break break
else: else:
function_bodies.append(completion_text.strip())
# Check whether the model wants to generate another turn # Check whether the model wants to generate another turn
completion_text = ""
for chunk in completion:
completion_text += chunk["choices"][0]["text"]
if len(chunk["choices"][0]["text"].rstrip()) > 0:
if tools is not None:
func_call_dict = {
"tool_calls": [
{
"index": tool_index,
"id": "call_" + tool_id,
"type": "function",
"function": {
"name": None,
"arguments": chunk["choices"][0]["text"].rstrip(),
},
}
]
}
else:
func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}}
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk_id,
object="chat.completion.chunk",
created=chunk_created,
model=chunk["model"],
choices=[
{
"index": 0,
"logprobs": chunk["choices"][0]["logprobs"],
"delta": {
"role": None,
"content": None,
**func_call_dict,
},
}
],
)
prompt += completion_text.strip() prompt += completion_text.strip()
grammar = None grammar = None
completion = create_completion(stop=stops) completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
completion_tokens += completion["usage"]["completion_tokens"] completion_text += "".join([chunk["choices"][0]["text"] for chunk in completion])
if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]: if ("<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text) and tools is not None:
prompt += "\n<|from|>assistant\n<|recipient|>" prompt += "\n<|from|>assistant\n<|recipient|>"
tool_index += 1
else: else:
# Yield tool_call/function_call stop message
yield llama_types.CreateChatCompletionStreamResponse(
id="chat" + chunk_id,
object="chat.completion.chunk",
created=chunk_created,
model=chunk["model"],
choices=[
{
"index": 0,
"finish_reason": "tool_calls" if tools is not None else "function_call",
"logprobs": None,
"delta": {
"role": None, "content": None, "function_call": None, "tool_calls": None
},
}
],
)
break break
assert "usage" in completion if stream is not False:
assert len(function_calls) == len(function_bodies) return generate_streaming(
tools=tools, functions=functions, function_call=function_call, prompt=prompt
tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
for function_call, function_body in zip(function_calls, function_bodies):
tool_calls.append(
{
"id": "call_"
+ "".join(
[
random.choice(string.ascii_letters + string.digits)
for _ in range(24)
]
),
"type": "function",
"function": {
"name": function_call,
"arguments": function_body,
},
}
) )
else:
if version == "v1":
# If no or "auto" tool_choice/function_call
if isinstance(function_call, str) and function_call == "auto":
stops = ["\n", END_ASSISTANT_TOKEN]
# If tool_choice/function_call is provided
elif isinstance(function_call, dict):
prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
stops = END_FUNCTION_CALL_TOKEN
function_call = function_call["name"]
function_calls.append(function_call)
grammar = get_grammar(function_call)
else:
prompt = prompt
stops = ["\n", END_ASSISTANT_TOKEN]
# TODO: support stream mode completion = create_completion(stop=stops)
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {} completion_text = completion["choices"][0]["text"]
if len(tool_calls) > 0: completion_tokens += completion["usage"]["completion_tokens"]
if tools is not None:
function_call_dict["tool_calls"] = tool_calls
# If the generation does not involve a function call
if (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN not in completion_text
):
completion["usage"]["completion_tokens"] = completion_tokens
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# If the generation involves a function call in completion, generate the parameters
elif (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN in completion_text
):
prompt += (
completion_text.replace(
f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN
)
+ "\n"
)
function_calls.append(
completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip()
)
grammar = get_grammar(function_calls[-1])
completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
completion_tokens += completion["usage"]["completion_tokens"]
function_bodies.append(completion["choices"][0]["text"].strip())
# If the prompt involves a function call, just append generated parameters to function_bodies
else:
function_bodies.append(completion_text.strip())
else: else:
function_call_dict["function_call"] = { # If tool_choice/function_call is provided
"name": tool_calls[0]["function"]["name"], if isinstance(function_call, dict):
"arguments": tool_calls[0]["function"]["arguments"], prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
} function_call = function_call["name"]
completion["usage"]["completion_tokens"] = completion_tokens function_calls.append(function_call)
return llama_types.CreateChatCompletionResponse( grammar = get_grammar(function_call)
id="chat" + completion["id"], stops = [STOP_TOKEN, FROM_TOKEN]
object="chat.completion", completion = create_completion(stop=stops)
created=completion["created"], completion_text = completion["choices"][0]["text"]
model=completion["model"], completion_tokens += completion["usage"]["completion_tokens"]
choices=[ function_bodies.append(completion_text.strip())
{ # If "auto" or no tool_choice/function_call
"index": 0, elif isinstance(function_call, str) and function_call == "auto":
"logprobs": completion["choices"][0]["logprobs"], while True:
"message": { # Generate function name first
"role": "assistant", grammar = None
"content": None if content == "" else content, stops = CONTENT_TOKEN
**function_call_dict, completion = create_completion(stop=stops)
}, completion_text = completion["choices"][0]["text"]
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop", completion_tokens += completion["usage"]["completion_tokens"]
} function_name = completion_text.strip()
], if function_name == "all":
usage=completion["usage"], prompt += "all\n<|content|>"
) else:
function_call = completion_text.strip()
prompt += f"{function_call}\n<|content|>"
function_calls.append(function_call)
grammar = get_grammar(function_call)
# Generate content
stops = [RECIPIENT_TOKEN, STOP_TOKEN]
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
completion_tokens += completion["usage"]["completion_tokens"]
if function_name == "all":
if completion_text.endswith("\n<|from|>assistant\n"):
content += completion_text[:-len("\n<|from|>assistant\n")]
if completion_text.endswith("\n<|from|> assistant\n"):
content += completion_text[-len("\n<|from|> assistant\n")]
else:
content += completion_text
content = content.lstrip()
# Check whether the model wants to generate another turn
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
if completion_text.endswith("\n<|from|>assistant\n"):
cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
elif completion_text.endswith("\n<|from|> assistant\n"):
cleaned_completion_text = completion_text[-len("\n<|from|> assistant\n")].strip()
else:
cleaned_completion_text = completion_text.strip()
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
else:
break
else:
function_bodies.append(completion_text.strip())
# Check whether the model wants to generate another turn
prompt += completion_text.strip()
grammar = None
completion = create_completion(stop=stops)
completion_tokens += completion["usage"]["completion_tokens"]
if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
prompt += "\n<|from|>assistant\n<|recipient|>"
else:
break
assert "usage" in completion
assert len(function_calls) == len(function_bodies)
tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
for function_call, function_body in zip(function_calls, function_bodies):
tool_calls.append(
{
"id": "call_"
+ "".join(
[
random.choice(string.ascii_letters + string.digits)
for _ in range(24)
]
),
"type": "function",
"function": {
"name": function_call,
"arguments": function_body,
},
}
)
# TODO: support stream mode
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {}
if len(tool_calls) > 0:
if tools is not None:
function_call_dict["tool_calls"] = tool_calls
else:
function_call_dict["function_call"] = {
"name": tool_calls[0]["function"]["name"],
"arguments": tool_calls[0]["function"]["arguments"],
}
completion["usage"]["completion_tokens"] = completion_tokens
return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"],
object="chat.completion",
created=completion["created"],
model=completion["model"],
choices=[
{
"index": 0,
"logprobs": completion["choices"][0]["logprobs"],
"message": {
"role": "assistant",
"content": None if content == "" else content,
**function_call_dict,
},
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
}
],
usage=completion["usage"],
)
class Llava15ChatHandler: class Llava15ChatHandler: