feat: Implement streaming for Functionary v2 + Bug fixes (#1419)
* set up streaming for v2 * assert v2 streaming, fix tool_call vs function_call * fix streaming with tool_choice/function_call * make functions return 1 function call only when 'auto' * fix --------- Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
parent
f9b7221c8f
commit
1f56c648c3
1 changed files with 443 additions and 133 deletions
|
@ -1894,6 +1894,8 @@ def functionary_v1_v2_chat_handler(
|
||||||
function_call = (
|
function_call = (
|
||||||
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
|
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
|
||||||
)
|
)
|
||||||
|
elif function_call is not None:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
function_call = "auto"
|
function_call = "auto"
|
||||||
|
|
||||||
|
@ -1930,11 +1932,10 @@ def functionary_v1_v2_chat_handler(
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
|
if stream is False:
|
||||||
|
completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip()
|
||||||
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
|
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
|
||||||
|
|
||||||
assert stream is False # TODO: support stream mode
|
|
||||||
|
|
||||||
def get_grammar(function_call):
|
def get_grammar(function_call):
|
||||||
function_body = None
|
function_body = None
|
||||||
for function in functions or []:
|
for function in functions or []:
|
||||||
|
@ -1968,7 +1969,7 @@ def functionary_v1_v2_chat_handler(
|
||||||
|
|
||||||
return grammar
|
return grammar
|
||||||
|
|
||||||
def create_completion(stop):
|
def create_completion(prompt, stop, grammar):
|
||||||
completion = cast(llama_types.Completion, llama.create_completion(
|
completion = cast(llama_types.Completion, llama.create_completion(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
@ -1976,7 +1977,7 @@ def functionary_v1_v2_chat_handler(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
min_p=min_p,
|
min_p=min_p,
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
stream=False,
|
stream=stream,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
|
@ -1996,172 +1997,481 @@ def functionary_v1_v2_chat_handler(
|
||||||
content = ""
|
content = ""
|
||||||
function_calls, function_bodies = [], []
|
function_calls, function_bodies = [], []
|
||||||
completion_tokens = 0
|
completion_tokens = 0
|
||||||
|
|
||||||
if version == "v1":
|
def generate_streaming(tools, functions, function_call, prompt):
|
||||||
# If no or "auto" tool_choice/function_call
|
assert version == "v2", "Streaming for v1 is not supported"
|
||||||
if isinstance(function_call, str) and function_call == "auto":
|
|
||||||
stops = ["\n", END_ASSISTANT_TOKEN]
|
chunk_id, chunk_created = None, None
|
||||||
# If tool_choice/function_call is provided
|
|
||||||
elif isinstance(function_call, dict):
|
|
||||||
prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
|
|
||||||
stops = END_FUNCTION_CALL_TOKEN
|
|
||||||
function_call = function_call["name"]
|
|
||||||
function_calls.append(function_call)
|
|
||||||
grammar = get_grammar(function_call)
|
|
||||||
else:
|
|
||||||
prompt = prompt
|
|
||||||
stops = ["\n", END_ASSISTANT_TOKEN]
|
|
||||||
|
|
||||||
completion = create_completion(stop=stops)
|
|
||||||
completion_text = completion["choices"][0]["text"]
|
|
||||||
completion_tokens += completion["usage"]["completion_tokens"]
|
|
||||||
|
|
||||||
|
|
||||||
# If the generation does not involve a function call
|
|
||||||
if (
|
|
||||||
START_FUNCTION_CALL_TOKEN not in prompt
|
|
||||||
and START_FUNCTION_CALL_TOKEN not in completion_text
|
|
||||||
):
|
|
||||||
completion["usage"]["completion_tokens"] = completion_tokens
|
|
||||||
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
|
||||||
# If the generation involves a function call in completion, generate the parameters
|
|
||||||
elif (
|
|
||||||
START_FUNCTION_CALL_TOKEN not in prompt
|
|
||||||
and START_FUNCTION_CALL_TOKEN in completion_text
|
|
||||||
):
|
|
||||||
prompt += (
|
|
||||||
completion_text.replace(
|
|
||||||
f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
function_calls.append(
|
|
||||||
completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip()
|
|
||||||
)
|
|
||||||
grammar = get_grammar(function_calls[-1])
|
|
||||||
completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
|
|
||||||
completion_tokens += completion["usage"]["completion_tokens"]
|
|
||||||
function_bodies.append(completion["choices"][0]["text"].strip())
|
|
||||||
# If the prompt involves a function call, just append generated parameters to function_bodies
|
|
||||||
else:
|
|
||||||
function_bodies.append(completion_text.strip())
|
|
||||||
else:
|
|
||||||
# If tool_choice/function_call is provided
|
# If tool_choice/function_call is provided
|
||||||
if isinstance(function_call, dict):
|
if isinstance(function_call, dict):
|
||||||
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
|
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
|
||||||
function_call = function_call["name"]
|
grammar = get_grammar(function_call["name"])
|
||||||
function_calls.append(function_call)
|
|
||||||
grammar = get_grammar(function_call)
|
|
||||||
stops = [STOP_TOKEN, FROM_TOKEN]
|
stops = [STOP_TOKEN, FROM_TOKEN]
|
||||||
completion = create_completion(stop=stops)
|
tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
|
||||||
completion_text = completion["choices"][0]["text"]
|
completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
|
||||||
completion_tokens += completion["usage"]["completion_tokens"]
|
completion_text = ""
|
||||||
function_bodies.append(completion_text.strip())
|
first = True
|
||||||
|
for chunk in completion:
|
||||||
|
# Yield the tool/function name first
|
||||||
|
if first:
|
||||||
|
if tools is not None:
|
||||||
|
func_call_dict = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"id": "call_" + tool_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": function_call["name"], "arguments": ""},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
func_call_dict = {"function_call": {"name": function_call["name"], "arguments": ""}}
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk["id"],
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk["created"],
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{"index": 0, "logprobs": None, "delta": {"role": None, "content": None, **func_call_dict}}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
first = False
|
||||||
|
if tools is not None:
|
||||||
|
func_call_dict = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"id": "call_" + tool_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": None,
|
||||||
|
"arguments": chunk["choices"][0]["text"].rstrip(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}}
|
||||||
|
if len(chunk["choices"][0]["text"].rstrip()) > 0:
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk["id"],
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk["created"],
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": chunk["choices"][0]["logprobs"],
|
||||||
|
"delta": {
|
||||||
|
"role": None,
|
||||||
|
"content": None,
|
||||||
|
**func_call_dict,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# Yield tool_call/function_call stop message
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk["id"],
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk["created"],
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "tool_calls" if tools is not None else "function_call",
|
||||||
|
"logprobs": None,
|
||||||
|
"delta": {
|
||||||
|
"role": None, "content": None, "function_call": None, "tool_calls": None
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
# If "auto" or no tool_choice/function_call
|
# If "auto" or no tool_choice/function_call
|
||||||
elif isinstance(function_call, str) and function_call == "auto":
|
elif isinstance(function_call, str) and function_call == "auto":
|
||||||
|
tool_index = 0
|
||||||
while True:
|
while True:
|
||||||
# Generate function name first
|
# Generate function name first
|
||||||
grammar = None
|
grammar = None
|
||||||
stops = CONTENT_TOKEN
|
stops = CONTENT_TOKEN
|
||||||
completion = create_completion(stop=stops)
|
completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
|
||||||
completion_text = completion["choices"][0]["text"]
|
completion_text = ""
|
||||||
completion_tokens += completion["usage"]["completion_tokens"]
|
for chunk in completion:
|
||||||
|
completion_text += chunk["choices"][0]["text"]
|
||||||
|
if chunk_id is None:
|
||||||
|
chunk_id = chunk["id"]
|
||||||
|
if chunk_created is None:
|
||||||
|
chunk_created = chunk["created"]
|
||||||
function_name = completion_text.strip()
|
function_name = completion_text.strip()
|
||||||
if function_name == "all":
|
if function_name == "all":
|
||||||
prompt += "all\n<|content|>"
|
prompt += "all\n<|content|>"
|
||||||
|
# Yield the first empty message for content
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk_id,
|
||||||
|
model=chunk["model"],
|
||||||
|
created=chunk_created,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant", "content": ""},
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
function_call = completion_text.strip()
|
prompt += f"{function_name}\n<|content|>"
|
||||||
prompt += f"{function_call}\n<|content|>"
|
grammar = get_grammar(function_name)
|
||||||
function_calls.append(function_call)
|
tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)])
|
||||||
grammar = get_grammar(function_call)
|
if tools is not None:
|
||||||
|
func_call_dict = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": tool_index,
|
||||||
|
"id": "call_" + tool_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": function_name, "arguments": ""},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
func_call_dict = {"function_call": {"name": function_name, "arguments": ""}}
|
||||||
|
# Stream function name
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk_id,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk_created,
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": chunk["choices"][0]["logprobs"],
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
**func_call_dict,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
# Generate content
|
# Generate content
|
||||||
stops = [RECIPIENT_TOKEN, STOP_TOKEN]
|
stops = [RECIPIENT_TOKEN, STOP_TOKEN]
|
||||||
completion = create_completion(stop=stops)
|
completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
|
||||||
completion_text = completion["choices"][0]["text"]
|
|
||||||
completion_tokens += completion["usage"]["completion_tokens"]
|
|
||||||
if function_name == "all":
|
if function_name == "all":
|
||||||
if completion_text.endswith("\n<|from|>assistant\n"):
|
completion_text = ""
|
||||||
content += completion_text[:-len("\n<|from|>assistant\n")]
|
stop_sequence, buffer, is_end = "\n<|from|>assistant\n<|recipient|>", [], False
|
||||||
if completion_text.endswith("\n<|from|> assistant\n"):
|
for i, chunk in enumerate(completion):
|
||||||
content += completion_text[-len("\n<|from|> assistant\n")]
|
completion_text += chunk["choices"][0]["text"]
|
||||||
else:
|
if is_end:
|
||||||
content += completion_text
|
buffer.append(chunk["choices"][0]["text"].strip(" "))
|
||||||
content = content.lstrip()
|
if stop_sequence.startswith("".join(buffer)):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
buffer.pop()
|
||||||
|
while len(buffer) > 0:
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk_id,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk_created,
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": chunk["choices"][0]["logprobs"],
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant", "content": buffer.pop(0)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
is_end = False
|
||||||
|
elif chunk["choices"][0]["text"] == "\n":
|
||||||
|
is_end = True
|
||||||
|
buffer.append(chunk["choices"][0]["text"].strip(" "))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(buffer) == 0 and len(chunk["choices"][0]["text"]) > 0:
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk_id,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk_created,
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": chunk["choices"][0]["logprobs"],
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": chunk["choices"][0]["text"] if i > 0 else chunk["choices"][0]["text"].lstrip()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
# Check whether the model wants to generate another turn
|
# Check whether the model wants to generate another turn
|
||||||
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
|
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
|
||||||
if completion_text.endswith("\n<|from|>assistant\n"):
|
if completion_text.endswith("\n<|from|>assistant\n"):
|
||||||
cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
|
cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
|
||||||
elif completion_text.endswith("\n<|from|> assistant\n"):
|
elif completion_text.endswith("\n<|from|> assistant\n"):
|
||||||
cleaned_completion_text = completion_text[-len("\n<|from|> assistant\n")].strip()
|
cleaned_completion_text = completion_text[:-len("\n<|from|> assistant\n")].strip()
|
||||||
else:
|
else:
|
||||||
cleaned_completion_text = completion_text.strip()
|
cleaned_completion_text = completion_text.strip()
|
||||||
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
|
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
|
||||||
else:
|
else:
|
||||||
|
# Yield stop message
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk_id,
|
||||||
|
model=chunk["model"],
|
||||||
|
created=chunk_created,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
function_bodies.append(completion_text.strip())
|
|
||||||
# Check whether the model wants to generate another turn
|
# Check whether the model wants to generate another turn
|
||||||
|
completion_text = ""
|
||||||
|
for chunk in completion:
|
||||||
|
completion_text += chunk["choices"][0]["text"]
|
||||||
|
if len(chunk["choices"][0]["text"].rstrip()) > 0:
|
||||||
|
if tools is not None:
|
||||||
|
func_call_dict = {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": tool_index,
|
||||||
|
"id": "call_" + tool_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": None,
|
||||||
|
"arguments": chunk["choices"][0]["text"].rstrip(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}}
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk_id,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk_created,
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": chunk["choices"][0]["logprobs"],
|
||||||
|
"delta": {
|
||||||
|
"role": None,
|
||||||
|
"content": None,
|
||||||
|
**func_call_dict,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
prompt += completion_text.strip()
|
prompt += completion_text.strip()
|
||||||
grammar = None
|
grammar = None
|
||||||
completion = create_completion(stop=stops)
|
completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
|
||||||
completion_tokens += completion["usage"]["completion_tokens"]
|
completion_text += "".join([chunk["choices"][0]["text"] for chunk in completion])
|
||||||
if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
|
if ("<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text) and tools is not None:
|
||||||
prompt += "\n<|from|>assistant\n<|recipient|>"
|
prompt += "\n<|from|>assistant\n<|recipient|>"
|
||||||
|
tool_index += 1
|
||||||
else:
|
else:
|
||||||
|
# Yield tool_call/function_call stop message
|
||||||
|
yield llama_types.CreateChatCompletionStreamResponse(
|
||||||
|
id="chat" + chunk_id,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
created=chunk_created,
|
||||||
|
model=chunk["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "tool_calls" if tools is not None else "function_call",
|
||||||
|
"logprobs": None,
|
||||||
|
"delta": {
|
||||||
|
"role": None, "content": None, "function_call": None, "tool_calls": None
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
assert "usage" in completion
|
if stream is not False:
|
||||||
assert len(function_calls) == len(function_bodies)
|
return generate_streaming(
|
||||||
|
tools=tools, functions=functions, function_call=function_call, prompt=prompt
|
||||||
tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
|
|
||||||
for function_call, function_body in zip(function_calls, function_bodies):
|
|
||||||
tool_calls.append(
|
|
||||||
{
|
|
||||||
"id": "call_"
|
|
||||||
+ "".join(
|
|
||||||
[
|
|
||||||
random.choice(string.ascii_letters + string.digits)
|
|
||||||
for _ in range(24)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": function_call,
|
|
||||||
"arguments": function_body,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if version == "v1":
|
||||||
|
# If no or "auto" tool_choice/function_call
|
||||||
|
if isinstance(function_call, str) and function_call == "auto":
|
||||||
|
stops = ["\n", END_ASSISTANT_TOKEN]
|
||||||
|
# If tool_choice/function_call is provided
|
||||||
|
elif isinstance(function_call, dict):
|
||||||
|
prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
|
||||||
|
stops = END_FUNCTION_CALL_TOKEN
|
||||||
|
function_call = function_call["name"]
|
||||||
|
function_calls.append(function_call)
|
||||||
|
grammar = get_grammar(function_call)
|
||||||
|
else:
|
||||||
|
prompt = prompt
|
||||||
|
stops = ["\n", END_ASSISTANT_TOKEN]
|
||||||
|
|
||||||
# TODO: support stream mode
|
completion = create_completion(stop=stops)
|
||||||
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {}
|
completion_text = completion["choices"][0]["text"]
|
||||||
if len(tool_calls) > 0:
|
completion_tokens += completion["usage"]["completion_tokens"]
|
||||||
if tools is not None:
|
|
||||||
function_call_dict["tool_calls"] = tool_calls
|
|
||||||
|
# If the generation does not involve a function call
|
||||||
|
if (
|
||||||
|
START_FUNCTION_CALL_TOKEN not in prompt
|
||||||
|
and START_FUNCTION_CALL_TOKEN not in completion_text
|
||||||
|
):
|
||||||
|
completion["usage"]["completion_tokens"] = completion_tokens
|
||||||
|
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
||||||
|
# If the generation involves a function call in completion, generate the parameters
|
||||||
|
elif (
|
||||||
|
START_FUNCTION_CALL_TOKEN not in prompt
|
||||||
|
and START_FUNCTION_CALL_TOKEN in completion_text
|
||||||
|
):
|
||||||
|
prompt += (
|
||||||
|
completion_text.replace(
|
||||||
|
f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
function_calls.append(
|
||||||
|
completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip()
|
||||||
|
)
|
||||||
|
grammar = get_grammar(function_calls[-1])
|
||||||
|
completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
|
||||||
|
completion_tokens += completion["usage"]["completion_tokens"]
|
||||||
|
function_bodies.append(completion["choices"][0]["text"].strip())
|
||||||
|
# If the prompt involves a function call, just append generated parameters to function_bodies
|
||||||
|
else:
|
||||||
|
function_bodies.append(completion_text.strip())
|
||||||
else:
|
else:
|
||||||
function_call_dict["function_call"] = {
|
# If tool_choice/function_call is provided
|
||||||
"name": tool_calls[0]["function"]["name"],
|
if isinstance(function_call, dict):
|
||||||
"arguments": tool_calls[0]["function"]["arguments"],
|
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
|
||||||
}
|
function_call = function_call["name"]
|
||||||
completion["usage"]["completion_tokens"] = completion_tokens
|
function_calls.append(function_call)
|
||||||
return llama_types.CreateChatCompletionResponse(
|
grammar = get_grammar(function_call)
|
||||||
id="chat" + completion["id"],
|
stops = [STOP_TOKEN, FROM_TOKEN]
|
||||||
object="chat.completion",
|
completion = create_completion(stop=stops)
|
||||||
created=completion["created"],
|
completion_text = completion["choices"][0]["text"]
|
||||||
model=completion["model"],
|
completion_tokens += completion["usage"]["completion_tokens"]
|
||||||
choices=[
|
function_bodies.append(completion_text.strip())
|
||||||
{
|
# If "auto" or no tool_choice/function_call
|
||||||
"index": 0,
|
elif isinstance(function_call, str) and function_call == "auto":
|
||||||
"logprobs": completion["choices"][0]["logprobs"],
|
while True:
|
||||||
"message": {
|
# Generate function name first
|
||||||
"role": "assistant",
|
grammar = None
|
||||||
"content": None if content == "" else content,
|
stops = CONTENT_TOKEN
|
||||||
**function_call_dict,
|
completion = create_completion(stop=stops)
|
||||||
},
|
completion_text = completion["choices"][0]["text"]
|
||||||
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
|
completion_tokens += completion["usage"]["completion_tokens"]
|
||||||
}
|
function_name = completion_text.strip()
|
||||||
],
|
if function_name == "all":
|
||||||
usage=completion["usage"],
|
prompt += "all\n<|content|>"
|
||||||
)
|
else:
|
||||||
|
function_call = completion_text.strip()
|
||||||
|
prompt += f"{function_call}\n<|content|>"
|
||||||
|
function_calls.append(function_call)
|
||||||
|
grammar = get_grammar(function_call)
|
||||||
|
# Generate content
|
||||||
|
stops = [RECIPIENT_TOKEN, STOP_TOKEN]
|
||||||
|
completion = create_completion(stop=stops)
|
||||||
|
completion_text = completion["choices"][0]["text"]
|
||||||
|
completion_tokens += completion["usage"]["completion_tokens"]
|
||||||
|
if function_name == "all":
|
||||||
|
if completion_text.endswith("\n<|from|>assistant\n"):
|
||||||
|
content += completion_text[:-len("\n<|from|>assistant\n")]
|
||||||
|
if completion_text.endswith("\n<|from|> assistant\n"):
|
||||||
|
content += completion_text[-len("\n<|from|> assistant\n")]
|
||||||
|
else:
|
||||||
|
content += completion_text
|
||||||
|
content = content.lstrip()
|
||||||
|
# Check whether the model wants to generate another turn
|
||||||
|
if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text:
|
||||||
|
if completion_text.endswith("\n<|from|>assistant\n"):
|
||||||
|
cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip()
|
||||||
|
elif completion_text.endswith("\n<|from|> assistant\n"):
|
||||||
|
cleaned_completion_text = completion_text[-len("\n<|from|> assistant\n")].strip()
|
||||||
|
else:
|
||||||
|
cleaned_completion_text = completion_text.strip()
|
||||||
|
prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
function_bodies.append(completion_text.strip())
|
||||||
|
# Check whether the model wants to generate another turn
|
||||||
|
prompt += completion_text.strip()
|
||||||
|
grammar = None
|
||||||
|
completion = create_completion(stop=stops)
|
||||||
|
completion_tokens += completion["usage"]["completion_tokens"]
|
||||||
|
if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]:
|
||||||
|
prompt += "\n<|from|>assistant\n<|recipient|>"
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
assert "usage" in completion
|
||||||
|
assert len(function_calls) == len(function_bodies)
|
||||||
|
|
||||||
|
tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
|
||||||
|
for function_call, function_body in zip(function_calls, function_bodies):
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": "call_"
|
||||||
|
+ "".join(
|
||||||
|
[
|
||||||
|
random.choice(string.ascii_letters + string.digits)
|
||||||
|
for _ in range(24)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": function_call,
|
||||||
|
"arguments": function_body,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: support stream mode
|
||||||
|
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {}
|
||||||
|
if len(tool_calls) > 0:
|
||||||
|
if tools is not None:
|
||||||
|
function_call_dict["tool_calls"] = tool_calls
|
||||||
|
else:
|
||||||
|
function_call_dict["function_call"] = {
|
||||||
|
"name": tool_calls[0]["function"]["name"],
|
||||||
|
"arguments": tool_calls[0]["function"]["arguments"],
|
||||||
|
}
|
||||||
|
completion["usage"]["completion_tokens"] = completion_tokens
|
||||||
|
return llama_types.CreateChatCompletionResponse(
|
||||||
|
id="chat" + completion["id"],
|
||||||
|
object="chat.completion",
|
||||||
|
created=completion["created"],
|
||||||
|
model=completion["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": completion["choices"][0]["logprobs"],
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None if content == "" else content,
|
||||||
|
**function_call_dict,
|
||||||
|
},
|
||||||
|
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
usage=completion["usage"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Llava15ChatHandler:
|
class Llava15ChatHandler:
|
||||||
|
|
Loading…
Reference in a new issue