fix: missing logprobs in response, incorrect response type for functionary, minor type issues. Closes #1328 #1314

Andrei Betlen 2024-04-05 10:50:49 -04:00
parent 9111b6e03a
commit 49bc66bfa2

llama_cpp/llama_chat_format.py

@@ -6,7 +6,7 @@ import ctypes
 import dataclasses
 import random
 import string
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
 import jinja2
@@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
                     }
                 ],
             },
+            "logprobs": None,
             "finish_reason": "tool_calls",
         }
     ],
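
For context, OpenAI-compatible chat completion responses carry a `logprobs` key on every choice object even when log probabilities were not requested; its absence is what this hunk fixes. A minimal sketch of the resulting choice shape (values are illustrative):

    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": None,
            "tool_calls": [],  # tool call entries omitted for brevity
        },
        "logprobs": None,  # must be present, and None, when not requested
        "finish_reason": "tool_calls",
    }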
@@ -1191,7 +1192,6 @@ def format_mistral_instruct(
         elif (
             message["role"] == "assistant"
             and message["content"] is not None
-            and isinstance(message["content"], str)
         ):
             prompt += " [/INST]" + message["content"] + eos
     prompt += " [/INST]"
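
For reference, the Mistral instruct template wraps user turns in `[INST] ... [/INST]` and terminates assistant turns with the EOS token. A simplified sketch of the prompt assembly this hunk edits, assuming plain string contents (not the library's exact code):

    def format_mistral(messages, eos="</s>"):
        # Build a Mistral-style instruct prompt from chat messages.
        prompt = ""
        for message in messages:
            if message["role"] == "user" and message["content"] is not None:
                prompt += "[INST] " + message["content"]
            elif message["role"] == "assistant" and message["content"] is not None:
                prompt += " [/INST]" + message["content"] + eos
        prompt += " [/INST]"  # leave the prompt open for the assistant's reply
        return prompt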
@@ -1263,7 +1263,7 @@ def format_gemma(
     **kwargs: Any,
 ) -> ChatFormatterResponse:
     system_message = _get_system_message(messages)
-    if system_message is not None and system_message != "":
+    if system_message != "":
         logger.debug(
             "`role='system'` messages are not allowed on Google's Gemma models."
         )
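
The simplified condition relies on `_get_system_message` returning an empty string, rather than None, when no system message is present, so the `is not None` clause added nothing. A minimal sketch of that assumed contract (hypothetical stand-in helper):

    def get_system_message(messages):
        # Assumed behavior: first system message's content, or "" if absent.
        for message in messages:
            if message["role"] == "system":
                return message["content"] or ""
        return ""

    assert get_system_message([{"role": "user", "content": "hi"}]) == ""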
@@ -1628,6 +1628,7 @@ def functionary_chat_handler(
                     }
                 ],
             },
+            "logprobs": None,
             "finish_reason": "tool_calls",
         }
     ],
@@ -1909,14 +1910,14 @@ def functionary_v1_v2_chat_handler(
         return grammar

     def create_completion(stop):
-        completion: llama_types.Completion = llama.create_completion(
+        completion = cast(llama_types.Completion, llama.create_completion(
             prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            stream=stream,
+            stream=False,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
@@ -1929,7 +1930,7 @@ def functionary_v1_v2_chat_handler(
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
-        )
+        ))
         return completion
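
`llama.create_completion` is typed as returning either a completion object or an iterator of stream chunks, so with `stream=False` hard-coded the concrete type is known to the author but not to the type checker; `typing.cast` narrows the union with no runtime effect. A standalone sketch of the pattern, with `Completion = dict` standing in for the library's TypedDict:

    from typing import Iterator, Union, cast

    Completion = dict  # stand-in for the library's Completion type

    def create_completion(stream: bool) -> Union[Completion, Iterator[Completion]]:
        # Returns an iterator of chunks when streaming, a single object otherwise.
        return iter([{"text": "..."}]) if stream else {"text": "..."}

    # cast() does nothing at runtime; it only tells the checker which
    # branch of the Union applies because stream=False was passed.
    completion = cast(Completion, create_completion(stream=False))
    print(completion["text"])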
@@ -2050,7 +2051,7 @@ def functionary_v1_v2_chat_handler(
         assert "usage" in completion
         assert len(function_calls) == len(function_bodies)

-        tool_calls = []
+        tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
         for function_call, function_body in zip(function_calls, function_bodies):
             tool_calls.append(
                 {
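
Annotating the accumulator up front lets the type checker validate each appended dict against the tool-call schema instead of inferring `List[Any]`. The shape of one entry in the OpenAI tool-call format (id and names here are illustrative):

    tool_call = {
        "id": "call_0",  # illustrative id
        "type": "function",
        "function": {
            "name": "get_weather",  # illustrative tool name
            "arguments": '{"city": "Paris"}',  # JSON-encoded argument string
        },
    }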
@@ -2070,6 +2071,12 @@ def functionary_v1_v2_chat_handler(
             )

         # TODO: support stream mode
+        function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
+            "function_call": {
+                "name": tool_calls[0]["function"]["name"],
+                "arguments": tool_calls[0]["function"]["arguments"],
+            }
+        } if len(tool_calls) == 1 else {}
         return llama_types.CreateChatCompletionResponse(
             id="chat" + completion["id"],
             object="chat.completion",
@@ -2078,14 +2085,12 @@ def functionary_v1_v2_chat_handler(
             choices=[
                 {
                     "index": 0,
+                    "logprobs": None,
                     "message": {
                         "role": "assistant",
                         "content": None if content == "" else content,
-                        "function_call": {
-                            "name": tool_calls[0]["function"]["name"],
-                            "arguments": tool_calls[0]["function"]["arguments"],
-                        } if len(tool_calls) > 0 else None,
-                        "tool_calls": tool_calls if len(tool_calls) > 0 else None,
+                        "tool_calls": tool_calls,
+                        **function_call_dict,
                     },
                     "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
                 }
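
Building `function_call_dict` separately and splatting it with `**` keeps the deprecated `function_call` field out of the message entirely unless exactly one tool was called, instead of emitting it as None. A minimal sketch of the conditional-spread pattern:

    tool_calls = [{"function": {"name": "f", "arguments": "{}"}}]  # illustrative

    function_call_dict = {
        "function_call": {
            "name": tool_calls[0]["function"]["name"],
            "arguments": tool_calls[0]["function"]["arguments"],
        }
    } if len(tool_calls) == 1 else {}

    message = {
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls,
        **function_call_dict,  # key present only for a single tool call
    }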
@@ -2565,8 +2570,8 @@ def chatml_function_calling(
         tool_name = text[len("functions.") :]
         tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
         if not stream:
-            completions = []
-            completions_tool_name = []
+            completions: List[llama_types.CreateCompletionResponse] = []
+            completions_tool_name: List[str] = []
             while tool is not None:
                 prompt += f"functions.{tool_name}:\n"
                 try:
@@ -2603,6 +2608,7 @@ def chatml_function_calling(
                     logits_processor=logits_processor,
                     grammar=grammar,
                 )
+                completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
                 completions.append(completion_or_chunks)
                 completions_tool_name.append(tool_name)
                 prompt += completion_or_chunks["choices"][0]["text"]
@@ -2631,6 +2637,7 @@ def chatml_function_calling(
                     follow_up_gbnf_tool_grammar, verbose=llama.verbose
                 ),
             )
+            response = cast(llama_types.CreateCompletionResponse, response)
             tool_name = response["choices"][0]["text"][len("functions.") :]
             tool = next(
@@ -2638,7 +2645,7 @@ def chatml_function_calling(
             )

             # Merge completions
-            function_call = {
+            function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
                 "function_call": {
                     "name": tool_name,
                     "arguments": completions[0]["choices"][0]["text"],
@@ -2653,6 +2660,7 @@ def chatml_function_calling(
                 {
                     "finish_reason": "tool_calls",
                     "index": 0,
+                    "logprobs": None,
                     "message": {
                         "role": "assistant",
                         "content": None,
@@ -2673,20 +2681,22 @@ def chatml_function_calling(
                                 zip(completions_tool_name, completions)
                             )
                         ],
-                        **function_call
+                        **function_call_dict
                     },
                 }
             ],
             "usage": {
                 "completion_tokens": sum(
-                    completion["usage"]["completion_tokens"]
+                    completion["usage"]["completion_tokens"] if "usage" in completion else 0
                     for completion in completions
                 ),
                 "prompt_tokens": sum(
-                    completion["usage"]["prompt_tokens"] for completion in completions
+                    completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                    for completion in completions
                 ),
                 "total_tokens": sum(
-                    completion["usage"]["total_tokens"] for completion in completions
+                    completion["usage"]["total_tokens"] if "usage" in completion else 0
+                    for completion in completions
                 ),
             },
         }
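
Because the completion type marks `usage` as optional, the guarded generator expressions total token counts across all merged completions without raising `KeyError` when a completion lacks the field. A standalone sketch of the same defensive totaling:

    completions = [
        {"usage": {"completion_tokens": 5, "prompt_tokens": 10, "total_tokens": 15}},
        {},  # a completion without a usage block contributes zero
    ]

    usage = {
        key: sum(c["usage"][key] if "usage" in c else 0 for c in completions)
        for key in ("completion_tokens", "prompt_tokens", "total_tokens")
    }
    print(usage)  # {'completion_tokens': 5, 'prompt_tokens': 10, 'total_tokens': 15}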