fix: missing logprobs in response, incorrect response type for functionary, minor type issues. Closes #1328 #1314
This commit is contained in:
parent
9111b6e03a
commit
49bc66bfa2
1 changed files with 29 additions and 19 deletions
|
@ -6,7 +6,7 @@ import ctypes
|
|||
import dataclasses
|
||||
import random
|
||||
import string
|
||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
|
||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
|
||||
|
||||
import jinja2
|
||||
|
||||
|
@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
|
|||
}
|
||||
],
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
|
@ -1191,7 +1192,6 @@ def format_mistral_instruct(
|
|||
elif (
|
||||
message["role"] == "assistant"
|
||||
and message["content"] is not None
|
||||
and isinstance(message["content"], str)
|
||||
):
|
||||
prompt += " [/INST]" + message["content"] + eos
|
||||
prompt += " [/INST]"
|
||||
|
@ -1263,7 +1263,7 @@ def format_gemma(
|
|||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
system_message = _get_system_message(messages)
|
||||
if system_message is not None and system_message != "":
|
||||
if system_message != "":
|
||||
logger.debug(
|
||||
"`role='system'` messages are not allowed on Google's Gemma models."
|
||||
)
|
||||
|
@ -1628,6 +1628,7 @@ def functionary_chat_handler(
|
|||
}
|
||||
],
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
|
@ -1909,14 +1910,14 @@ def functionary_v1_v2_chat_handler(
|
|||
return grammar
|
||||
|
||||
def create_completion(stop):
|
||||
completion: llama_types.Completion = llama.create_completion(
|
||||
completion = cast(llama_types.Completion, llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stream=False,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
|
@ -1929,7 +1930,7 @@ def functionary_v1_v2_chat_handler(
|
|||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
)
|
||||
))
|
||||
|
||||
return completion
|
||||
|
||||
|
@ -2050,7 +2051,7 @@ def functionary_v1_v2_chat_handler(
|
|||
assert "usage" in completion
|
||||
assert len(function_calls) == len(function_bodies)
|
||||
|
||||
tool_calls = []
|
||||
tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
|
||||
for function_call, function_body in zip(function_calls, function_bodies):
|
||||
tool_calls.append(
|
||||
{
|
||||
|
@ -2070,6 +2071,12 @@ def functionary_v1_v2_chat_handler(
|
|||
)
|
||||
|
||||
# TODO: support stream mode
|
||||
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
|
||||
"function_call": {
|
||||
"name": tool_calls[0]["function"]["name"],
|
||||
"arguments": tool_calls[0]["function"]["arguments"],
|
||||
}
|
||||
} if len(tool_calls) == 1 else {}
|
||||
return llama_types.CreateChatCompletionResponse(
|
||||
id="chat" + completion["id"],
|
||||
object="chat.completion",
|
||||
|
@ -2078,14 +2085,12 @@ def functionary_v1_v2_chat_handler(
|
|||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None if content == "" else content,
|
||||
"function_call": {
|
||||
"name": tool_calls[0]["function"]["name"],
|
||||
"arguments": tool_calls[0]["function"]["arguments"],
|
||||
} if len(tool_calls) > 0 else None,
|
||||
"tool_calls": tool_calls if len(tool_calls) > 0 else None,
|
||||
"tool_calls": tool_calls,
|
||||
**function_call_dict,
|
||||
},
|
||||
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
|
||||
}
|
||||
|
@ -2565,8 +2570,8 @@ def chatml_function_calling(
|
|||
tool_name = text[len("functions.") :]
|
||||
tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
|
||||
if not stream:
|
||||
completions = []
|
||||
completions_tool_name = []
|
||||
completions: List[llama_types.CreateCompletionResponse] = []
|
||||
completions_tool_name: List[str] = []
|
||||
while tool is not None:
|
||||
prompt += f"functions.{tool_name}:\n"
|
||||
try:
|
||||
|
@ -2603,6 +2608,7 @@ def chatml_function_calling(
|
|||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
)
|
||||
completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
|
||||
completions.append(completion_or_chunks)
|
||||
completions_tool_name.append(tool_name)
|
||||
prompt += completion_or_chunks["choices"][0]["text"]
|
||||
|
@ -2631,6 +2637,7 @@ def chatml_function_calling(
|
|||
follow_up_gbnf_tool_grammar, verbose=llama.verbose
|
||||
),
|
||||
)
|
||||
response = cast(llama_types.CreateCompletionResponse, response)
|
||||
|
||||
tool_name = response["choices"][0]["text"][len("functions.") :]
|
||||
tool = next(
|
||||
|
@ -2638,7 +2645,7 @@ def chatml_function_calling(
|
|||
)
|
||||
|
||||
# Merge completions
|
||||
function_call = {
|
||||
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
|
||||
"function_call": {
|
||||
"name": tool_name,
|
||||
"arguments": completions[0]["choices"][0]["text"],
|
||||
|
@ -2653,6 +2660,7 @@ def chatml_function_calling(
|
|||
{
|
||||
"finish_reason": "tool_calls",
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
|
@ -2673,20 +2681,22 @@ def chatml_function_calling(
|
|||
zip(completions_tool_name, completions)
|
||||
)
|
||||
],
|
||||
**function_call
|
||||
**function_call_dict
|
||||
},
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"completion_tokens": sum(
|
||||
completion["usage"]["completion_tokens"]
|
||||
completion["usage"]["completion_tokens"] if "usage" in completion else 0
|
||||
for completion in completions
|
||||
),
|
||||
"prompt_tokens": sum(
|
||||
completion["usage"]["prompt_tokens"] for completion in completions
|
||||
completion["usage"]["prompt_tokens"] if "usage" in completion else 0
|
||||
for completion in completions
|
||||
),
|
||||
"total_tokens": sum(
|
||||
completion["usage"]["total_tokens"] for completion in completions
|
||||
completion["usage"]["total_tokens"] if "usage" in completion else 0
|
||||
for completion in completions
|
||||
),
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue