From 1f56c648c3a4aed2b4fe3edb4dc745fae3e7d8ae Mon Sep 17 00:00:00 2001 From: Jeffrey Fong Date: Sat, 4 May 2024 22:11:20 +0800 Subject: [PATCH] feat: Implement streaming for Functionary v2 + Bug fixes (#1419) * set up streaming for v2 * assert v2 streaming, fix tool_call vs function_call * fix streaming with tool_choice/function_call * make functions return 1 function call only when 'auto' * fix --------- Co-authored-by: Andrei --- llama_cpp/llama_chat_format.py | 576 +++++++++++++++++++++++++-------- 1 file changed, 443 insertions(+), 133 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index a17b86b..3ab94e0 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1894,6 +1894,8 @@ def functionary_v1_v2_chat_handler( function_call = ( tool_choice if isinstance(tool_choice, str) else tool_choice["function"] ) + elif function_call is not None: + pass else: function_call = "auto" @@ -1930,11 +1932,10 @@ def functionary_v1_v2_chat_handler( logits_processor=logits_processor, grammar=grammar, ) - completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip() + if stream is False: + completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip() return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore - assert stream is False # TODO: support stream mode - def get_grammar(function_call): function_body = None for function in functions or []: @@ -1968,7 +1969,7 @@ def functionary_v1_v2_chat_handler( return grammar - def create_completion(stop): + def create_completion(prompt, stop, grammar): completion = cast(llama_types.Completion, llama.create_completion( prompt=prompt, temperature=temperature, @@ -1976,7 +1977,7 @@ def functionary_v1_v2_chat_handler( top_k=top_k, min_p=min_p, typical_p=typical_p, - stream=False, + stream=stream, stop=stop, max_tokens=max_tokens, presence_penalty=presence_penalty, @@ -1996,172 +1997,481 @@ def functionary_v1_v2_chat_handler( content = "" function_calls, function_bodies = [], [] completion_tokens = 0 - - if version == "v1": - # If no or "auto" tool_choice/function_call - if isinstance(function_call, str) and function_call == "auto": - stops = ["\n", END_ASSISTANT_TOKEN] - # If tool_choice/function_call is provided - elif isinstance(function_call, dict): - prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n" - stops = END_FUNCTION_CALL_TOKEN - function_call = function_call["name"] - function_calls.append(function_call) - grammar = get_grammar(function_call) - else: - prompt = prompt - stops = ["\n", END_ASSISTANT_TOKEN] - - completion = create_completion(stop=stops) - completion_text = completion["choices"][0]["text"] - completion_tokens += completion["usage"]["completion_tokens"] + + def generate_streaming(tools, functions, function_call, prompt): + assert version == "v2", "Streaming for v1 is not supported" + + chunk_id, chunk_created = None, None - - # If the generation does not involve a function call - if ( - START_FUNCTION_CALL_TOKEN not in prompt - and START_FUNCTION_CALL_TOKEN not in completion_text - ): - completion["usage"]["completion_tokens"] = completion_tokens - return _convert_completion_to_chat(completion, stream=stream) # type: ignore - # If the generation involves a function call in completion, generate the parameters - elif ( - START_FUNCTION_CALL_TOKEN not in prompt - and START_FUNCTION_CALL_TOKEN in completion_text - ): - prompt += ( - completion_text.replace( - f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN - ) - + "\n" - ) - function_calls.append( - completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip() - ) - grammar = get_grammar(function_calls[-1]) - completion = create_completion(stop=END_FUNCTION_CALL_TOKEN) - completion_tokens += completion["usage"]["completion_tokens"] - function_bodies.append(completion["choices"][0]["text"].strip()) - # If the prompt involves a function call, just append generated parameters to function_bodies - else: - function_bodies.append(completion_text.strip()) - else: # If tool_choice/function_call is provided if isinstance(function_call, dict): prompt += f"{function_call['name']}\n{CONTENT_TOKEN}" - function_call = function_call["name"] - function_calls.append(function_call) - grammar = get_grammar(function_call) + grammar = get_grammar(function_call["name"]) stops = [STOP_TOKEN, FROM_TOKEN] - completion = create_completion(stop=stops) - completion_text = completion["choices"][0]["text"] - completion_tokens += completion["usage"]["completion_tokens"] - function_bodies.append(completion_text.strip()) + tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)]) + completion = create_completion(prompt=prompt, stop=stops, grammar=grammar) + completion_text = "" + first = True + for chunk in completion: + # Yield the tool/function name first + if first: + if tools is not None: + func_call_dict = { + "tool_calls": [ + { + "index": 0, + "id": "call_" + tool_id, + "type": "function", + "function": {"name": function_call["name"], "arguments": ""}, + } + ] + } + else: + func_call_dict = {"function_call": {"name": function_call["name"], "arguments": ""}} + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk["id"], + object="chat.completion.chunk", + created=chunk["created"], + model=chunk["model"], + choices=[ + {"index": 0, "logprobs": None, "delta": {"role": None, "content": None, **func_call_dict}} + ], + ) + first = False + if tools is not None: + func_call_dict = { + "tool_calls": [ + { + "index": 0, + "id": "call_" + tool_id, + "type": "function", + "function": { + "name": None, + "arguments": chunk["choices"][0]["text"].rstrip(), + }, + } + ] + } + else: + func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}} + if len(chunk["choices"][0]["text"].rstrip()) > 0: + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk["id"], + object="chat.completion.chunk", + created=chunk["created"], + model=chunk["model"], + choices=[ + { + "index": 0, + "logprobs": chunk["choices"][0]["logprobs"], + "delta": { + "role": None, + "content": None, + **func_call_dict, + }, + } + ], + ) + # Yield tool_call/function_call stop message + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk["id"], + object="chat.completion.chunk", + created=chunk["created"], + model=chunk["model"], + choices=[ + { + "index": 0, + "finish_reason": "tool_calls" if tools is not None else "function_call", + "logprobs": None, + "delta": { + "role": None, "content": None, "function_call": None, "tool_calls": None + }, + } + ], + ) # If "auto" or no tool_choice/function_call elif isinstance(function_call, str) and function_call == "auto": + tool_index = 0 while True: # Generate function name first grammar = None stops = CONTENT_TOKEN - completion = create_completion(stop=stops) - completion_text = completion["choices"][0]["text"] - completion_tokens += completion["usage"]["completion_tokens"] + completion = create_completion(prompt=prompt, stop=stops, grammar=grammar) + completion_text = "" + for chunk in completion: + completion_text += chunk["choices"][0]["text"] + if chunk_id is None: + chunk_id = chunk["id"] + if chunk_created is None: + chunk_created = chunk["created"] function_name = completion_text.strip() if function_name == "all": prompt += "all\n<|content|>" + # Yield the first empty message for content + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk_id, + model=chunk["model"], + created=chunk_created, + object="chat.completion.chunk", + choices=[ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "logprobs": None, + "finish_reason": None, + } + ], + ) else: - function_call = completion_text.strip() - prompt += f"{function_call}\n<|content|>" - function_calls.append(function_call) - grammar = get_grammar(function_call) + prompt += f"{function_name}\n<|content|>" + grammar = get_grammar(function_name) + tool_id = "".join([random.choice(string.ascii_letters + string.digits) for _ in range(24)]) + if tools is not None: + func_call_dict = { + "tool_calls": [ + { + "index": tool_index, + "id": "call_" + tool_id, + "type": "function", + "function": {"name": function_name, "arguments": ""}, + } + ] + } + else: + func_call_dict = {"function_call": {"name": function_name, "arguments": ""}} + # Stream function name + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk_id, + object="chat.completion.chunk", + created=chunk_created, + model=chunk["model"], + choices=[ + { + "index": 0, + "logprobs": chunk["choices"][0]["logprobs"], + "delta": { + "role": "assistant", + "content": None, + **func_call_dict, + }, + } + ], + ) # Generate content stops = [RECIPIENT_TOKEN, STOP_TOKEN] - completion = create_completion(stop=stops) - completion_text = completion["choices"][0]["text"] - completion_tokens += completion["usage"]["completion_tokens"] + completion = create_completion(prompt=prompt, stop=stops, grammar=grammar) if function_name == "all": - if completion_text.endswith("\n<|from|>assistant\n"): - content += completion_text[:-len("\n<|from|>assistant\n")] - if completion_text.endswith("\n<|from|> assistant\n"): - content += completion_text[-len("\n<|from|> assistant\n")] - else: - content += completion_text - content = content.lstrip() + completion_text = "" + stop_sequence, buffer, is_end = "\n<|from|>assistant\n<|recipient|>", [], False + for i, chunk in enumerate(completion): + completion_text += chunk["choices"][0]["text"] + if is_end: + buffer.append(chunk["choices"][0]["text"].strip(" ")) + if stop_sequence.startswith("".join(buffer)): + continue + else: + buffer.pop() + while len(buffer) > 0: + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk_id, + object="chat.completion.chunk", + created=chunk_created, + model=chunk["model"], + choices=[ + { + "index": 0, + "logprobs": chunk["choices"][0]["logprobs"], + "delta": { + "role": "assistant", "content": buffer.pop(0) + }, + } + ], + ) + is_end = False + elif chunk["choices"][0]["text"] == "\n": + is_end = True + buffer.append(chunk["choices"][0]["text"].strip(" ")) + continue + + if len(buffer) == 0 and len(chunk["choices"][0]["text"]) > 0: + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk_id, + object="chat.completion.chunk", + created=chunk_created, + model=chunk["model"], + choices=[ + { + "index": 0, + "logprobs": chunk["choices"][0]["logprobs"], + "delta": { + "role": "assistant", + "content": chunk["choices"][0]["text"] if i > 0 else chunk["choices"][0]["text"].lstrip() + }, + } + ], + ) # Check whether the model wants to generate another turn if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text: if completion_text.endswith("\n<|from|>assistant\n"): cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip() elif completion_text.endswith("\n<|from|> assistant\n"): - cleaned_completion_text = completion_text[-len("\n<|from|> assistant\n")].strip() + cleaned_completion_text = completion_text[:-len("\n<|from|> assistant\n")].strip() else: cleaned_completion_text = completion_text.strip() prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>" else: + # Yield stop message + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk_id, + model=chunk["model"], + created=chunk_created, + object="chat.completion.chunk", + choices=[ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop", + } + ], + ) break else: - function_bodies.append(completion_text.strip()) # Check whether the model wants to generate another turn + completion_text = "" + for chunk in completion: + completion_text += chunk["choices"][0]["text"] + if len(chunk["choices"][0]["text"].rstrip()) > 0: + if tools is not None: + func_call_dict = { + "tool_calls": [ + { + "index": tool_index, + "id": "call_" + tool_id, + "type": "function", + "function": { + "name": None, + "arguments": chunk["choices"][0]["text"].rstrip(), + }, + } + ] + } + else: + func_call_dict = {"function_call": {"name": None, "arguments": chunk["choices"][0]["text"].rstrip()}} + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk_id, + object="chat.completion.chunk", + created=chunk_created, + model=chunk["model"], + choices=[ + { + "index": 0, + "logprobs": chunk["choices"][0]["logprobs"], + "delta": { + "role": None, + "content": None, + **func_call_dict, + }, + } + ], + ) prompt += completion_text.strip() grammar = None - completion = create_completion(stop=stops) - completion_tokens += completion["usage"]["completion_tokens"] - if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]: + completion = create_completion(prompt=prompt, stop=stops, grammar=grammar) + completion_text += "".join([chunk["choices"][0]["text"] for chunk in completion]) + if ("<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text) and tools is not None: prompt += "\n<|from|>assistant\n<|recipient|>" + tool_index += 1 else: + # Yield tool_call/function_call stop message + yield llama_types.CreateChatCompletionStreamResponse( + id="chat" + chunk_id, + object="chat.completion.chunk", + created=chunk_created, + model=chunk["model"], + choices=[ + { + "index": 0, + "finish_reason": "tool_calls" if tools is not None else "function_call", + "logprobs": None, + "delta": { + "role": None, "content": None, "function_call": None, "tool_calls": None + }, + } + ], + ) break - - assert "usage" in completion - assert len(function_calls) == len(function_bodies) - - tool_calls: List[llama_types.ChatCompletionMessageToolCall] = [] - for function_call, function_body in zip(function_calls, function_bodies): - tool_calls.append( - { - "id": "call_" - + "".join( - [ - random.choice(string.ascii_letters + string.digits) - for _ in range(24) - ] - ), - "type": "function", - "function": { - "name": function_call, - "arguments": function_body, - }, - } + + if stream is not False: + return generate_streaming( + tools=tools, functions=functions, function_call=function_call, prompt=prompt ) + else: + if version == "v1": + # If no or "auto" tool_choice/function_call + if isinstance(function_call, str) and function_call == "auto": + stops = ["\n", END_ASSISTANT_TOKEN] + # If tool_choice/function_call is provided + elif isinstance(function_call, dict): + prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n" + stops = END_FUNCTION_CALL_TOKEN + function_call = function_call["name"] + function_calls.append(function_call) + grammar = get_grammar(function_call) + else: + prompt = prompt + stops = ["\n", END_ASSISTANT_TOKEN] - # TODO: support stream mode - function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {} - if len(tool_calls) > 0: - if tools is not None: - function_call_dict["tool_calls"] = tool_calls + completion = create_completion(stop=stops) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] + + + # If the generation does not involve a function call + if ( + START_FUNCTION_CALL_TOKEN not in prompt + and START_FUNCTION_CALL_TOKEN not in completion_text + ): + completion["usage"]["completion_tokens"] = completion_tokens + return _convert_completion_to_chat(completion, stream=stream) # type: ignore + # If the generation involves a function call in completion, generate the parameters + elif ( + START_FUNCTION_CALL_TOKEN not in prompt + and START_FUNCTION_CALL_TOKEN in completion_text + ): + prompt += ( + completion_text.replace( + f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN + ) + + "\n" + ) + function_calls.append( + completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip() + ) + grammar = get_grammar(function_calls[-1]) + completion = create_completion(stop=END_FUNCTION_CALL_TOKEN) + completion_tokens += completion["usage"]["completion_tokens"] + function_bodies.append(completion["choices"][0]["text"].strip()) + # If the prompt involves a function call, just append generated parameters to function_bodies + else: + function_bodies.append(completion_text.strip()) else: - function_call_dict["function_call"] = { - "name": tool_calls[0]["function"]["name"], - "arguments": tool_calls[0]["function"]["arguments"], - } - completion["usage"]["completion_tokens"] = completion_tokens - return llama_types.CreateChatCompletionResponse( - id="chat" + completion["id"], - object="chat.completion", - created=completion["created"], - model=completion["model"], - choices=[ - { - "index": 0, - "logprobs": completion["choices"][0]["logprobs"], - "message": { - "role": "assistant", - "content": None if content == "" else content, - **function_call_dict, - }, - "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop", - } - ], - usage=completion["usage"], - ) + # If tool_choice/function_call is provided + if isinstance(function_call, dict): + prompt += f"{function_call['name']}\n{CONTENT_TOKEN}" + function_call = function_call["name"] + function_calls.append(function_call) + grammar = get_grammar(function_call) + stops = [STOP_TOKEN, FROM_TOKEN] + completion = create_completion(stop=stops) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] + function_bodies.append(completion_text.strip()) + # If "auto" or no tool_choice/function_call + elif isinstance(function_call, str) and function_call == "auto": + while True: + # Generate function name first + grammar = None + stops = CONTENT_TOKEN + completion = create_completion(stop=stops) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] + function_name = completion_text.strip() + if function_name == "all": + prompt += "all\n<|content|>" + else: + function_call = completion_text.strip() + prompt += f"{function_call}\n<|content|>" + function_calls.append(function_call) + grammar = get_grammar(function_call) + # Generate content + stops = [RECIPIENT_TOKEN, STOP_TOKEN] + completion = create_completion(stop=stops) + completion_text = completion["choices"][0]["text"] + completion_tokens += completion["usage"]["completion_tokens"] + if function_name == "all": + if completion_text.endswith("\n<|from|>assistant\n"): + content += completion_text[:-len("\n<|from|>assistant\n")] + if completion_text.endswith("\n<|from|> assistant\n"): + content += completion_text[-len("\n<|from|> assistant\n")] + else: + content += completion_text + content = content.lstrip() + # Check whether the model wants to generate another turn + if "<|from|> assistant" in completion_text or "<|from|>assistant" in completion_text: + if completion_text.endswith("\n<|from|>assistant\n"): + cleaned_completion_text = completion_text[:-len("\n<|from|>assistant\n")].strip() + elif completion_text.endswith("\n<|from|> assistant\n"): + cleaned_completion_text = completion_text[-len("\n<|from|> assistant\n")].strip() + else: + cleaned_completion_text = completion_text.strip() + prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>" + else: + break + else: + function_bodies.append(completion_text.strip()) + # Check whether the model wants to generate another turn + prompt += completion_text.strip() + grammar = None + completion = create_completion(stop=stops) + completion_tokens += completion["usage"]["completion_tokens"] + if "<|from|> assistant" in completion["choices"][0]["text"] or "<|from|>assistant" in completion["choices"][0]["text"]: + prompt += "\n<|from|>assistant\n<|recipient|>" + else: + break + + assert "usage" in completion + assert len(function_calls) == len(function_bodies) + + tool_calls: List[llama_types.ChatCompletionMessageToolCall] = [] + for function_call, function_body in zip(function_calls, function_bodies): + tool_calls.append( + { + "id": "call_" + + "".join( + [ + random.choice(string.ascii_letters + string.digits) + for _ in range(24) + ] + ), + "type": "function", + "function": { + "name": function_call, + "arguments": function_body, + }, + } + ) + + # TODO: support stream mode + function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {} + if len(tool_calls) > 0: + if tools is not None: + function_call_dict["tool_calls"] = tool_calls + else: + function_call_dict["function_call"] = { + "name": tool_calls[0]["function"]["name"], + "arguments": tool_calls[0]["function"]["arguments"], + } + completion["usage"]["completion_tokens"] = completion_tokens + return llama_types.CreateChatCompletionResponse( + id="chat" + completion["id"], + object="chat.completion", + created=completion["created"], + model=completion["model"], + choices=[ + { + "index": 0, + "logprobs": completion["choices"][0]["logprobs"], + "message": { + "role": "assistant", + "content": None if content == "" else content, + **function_call_dict, + }, + "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop", + } + ], + usage=completion["usage"], + ) class Llava15ChatHandler: