fix: missing logprobs in response, incorrect response type for functionary, minor type issues. Closes #1328 Closes #1314
This commit is contained in:
parent
9111b6e03a
commit
1ae3abbcc3
1 changed files with 29 additions and 19 deletions
|
@ -6,7 +6,7 @@ import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
|
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
"logprobs": None,
|
||||||
"finish_reason": "tool_calls",
|
"finish_reason": "tool_calls",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -1191,7 +1192,6 @@ def format_mistral_instruct(
|
||||||
elif (
|
elif (
|
||||||
message["role"] == "assistant"
|
message["role"] == "assistant"
|
||||||
and message["content"] is not None
|
and message["content"] is not None
|
||||||
and isinstance(message["content"], str)
|
|
||||||
):
|
):
|
||||||
prompt += " [/INST]" + message["content"] + eos
|
prompt += " [/INST]" + message["content"] + eos
|
||||||
prompt += " [/INST]"
|
prompt += " [/INST]"
|
||||||
|
@ -1263,7 +1263,7 @@ def format_gemma(
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
system_message = _get_system_message(messages)
|
system_message = _get_system_message(messages)
|
||||||
if system_message is not None and system_message != "":
|
if system_message != "":
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"`role='system'` messages are not allowed on Google's Gemma models."
|
"`role='system'` messages are not allowed on Google's Gemma models."
|
||||||
)
|
)
|
||||||
|
@ -1628,6 +1628,7 @@ def functionary_chat_handler(
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
"logprobs": None,
|
||||||
"finish_reason": "tool_calls",
|
"finish_reason": "tool_calls",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -1909,14 +1910,14 @@ def functionary_v1_v2_chat_handler(
|
||||||
return grammar
|
return grammar
|
||||||
|
|
||||||
def create_completion(stop):
|
def create_completion(stop):
|
||||||
completion: llama_types.Completion = llama.create_completion(
|
completion = cast(llama_types.Completion, llama.create_completion(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
min_p=min_p,
|
min_p=min_p,
|
||||||
typical_p=typical_p,
|
typical_p=typical_p,
|
||||||
stream=stream,
|
stream=False,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
|
@ -1929,7 +1930,7 @@ def functionary_v1_v2_chat_handler(
|
||||||
model=model,
|
model=model,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
)
|
))
|
||||||
|
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
|
@ -2050,7 +2051,7 @@ def functionary_v1_v2_chat_handler(
|
||||||
assert "usage" in completion
|
assert "usage" in completion
|
||||||
assert len(function_calls) == len(function_bodies)
|
assert len(function_calls) == len(function_bodies)
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
|
||||||
for function_call, function_body in zip(function_calls, function_bodies):
|
for function_call, function_body in zip(function_calls, function_bodies):
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
{
|
{
|
||||||
|
@ -2070,6 +2071,12 @@ def functionary_v1_v2_chat_handler(
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: support stream mode
|
# TODO: support stream mode
|
||||||
|
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
|
||||||
|
"function_call": {
|
||||||
|
"name": tool_calls[0]["function"]["name"],
|
||||||
|
"arguments": tool_calls[0]["function"]["arguments"],
|
||||||
|
}
|
||||||
|
} if len(tool_calls) == 1 else {}
|
||||||
return llama_types.CreateChatCompletionResponse(
|
return llama_types.CreateChatCompletionResponse(
|
||||||
id="chat" + completion["id"],
|
id="chat" + completion["id"],
|
||||||
object="chat.completion",
|
object="chat.completion",
|
||||||
|
@ -2078,14 +2085,12 @@ def functionary_v1_v2_chat_handler(
|
||||||
choices=[
|
choices=[
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": None if content == "" else content,
|
"content": None if content == "" else content,
|
||||||
"function_call": {
|
"tool_calls": tool_calls,
|
||||||
"name": tool_calls[0]["function"]["name"],
|
**function_call_dict,
|
||||||
"arguments": tool_calls[0]["function"]["arguments"],
|
|
||||||
} if len(tool_calls) > 0 else None,
|
|
||||||
"tool_calls": tool_calls if len(tool_calls) > 0 else None,
|
|
||||||
},
|
},
|
||||||
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
|
"finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
|
||||||
}
|
}
|
||||||
|
@ -2565,8 +2570,8 @@ def chatml_function_calling(
|
||||||
tool_name = text[len("functions.") :]
|
tool_name = text[len("functions.") :]
|
||||||
tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
|
tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
|
||||||
if not stream:
|
if not stream:
|
||||||
completions = []
|
completions: List[llama_types.CreateCompletionResponse] = []
|
||||||
completions_tool_name = []
|
completions_tool_name: List[str] = []
|
||||||
while tool is not None:
|
while tool is not None:
|
||||||
prompt += f"functions.{tool_name}:\n"
|
prompt += f"functions.{tool_name}:\n"
|
||||||
try:
|
try:
|
||||||
|
@ -2603,6 +2608,7 @@ def chatml_function_calling(
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
|
completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
|
||||||
completions.append(completion_or_chunks)
|
completions.append(completion_or_chunks)
|
||||||
completions_tool_name.append(tool_name)
|
completions_tool_name.append(tool_name)
|
||||||
prompt += completion_or_chunks["choices"][0]["text"]
|
prompt += completion_or_chunks["choices"][0]["text"]
|
||||||
|
@ -2631,6 +2637,7 @@ def chatml_function_calling(
|
||||||
follow_up_gbnf_tool_grammar, verbose=llama.verbose
|
follow_up_gbnf_tool_grammar, verbose=llama.verbose
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
response = cast(llama_types.CreateCompletionResponse, response)
|
||||||
|
|
||||||
tool_name = response["choices"][0]["text"][len("functions.") :]
|
tool_name = response["choices"][0]["text"][len("functions.") :]
|
||||||
tool = next(
|
tool = next(
|
||||||
|
@ -2638,7 +2645,7 @@ def chatml_function_calling(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Merge completions
|
# Merge completions
|
||||||
function_call = {
|
function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
|
||||||
"function_call": {
|
"function_call": {
|
||||||
"name": tool_name,
|
"name": tool_name,
|
||||||
"arguments": completions[0]["choices"][0]["text"],
|
"arguments": completions[0]["choices"][0]["text"],
|
||||||
|
@ -2653,6 +2660,7 @@ def chatml_function_calling(
|
||||||
{
|
{
|
||||||
"finish_reason": "tool_calls",
|
"finish_reason": "tool_calls",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
|
"logprobs": None,
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": None,
|
"content": None,
|
||||||
|
@ -2673,20 +2681,22 @@ def chatml_function_calling(
|
||||||
zip(completions_tool_name, completions)
|
zip(completions_tool_name, completions)
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
**function_call
|
**function_call_dict
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"usage": {
|
"usage": {
|
||||||
"completion_tokens": sum(
|
"completion_tokens": sum(
|
||||||
completion["usage"]["completion_tokens"]
|
completion["usage"]["completion_tokens"] if "usage" in completion else 0
|
||||||
for completion in completions
|
for completion in completions
|
||||||
),
|
),
|
||||||
"prompt_tokens": sum(
|
"prompt_tokens": sum(
|
||||||
completion["usage"]["prompt_tokens"] for completion in completions
|
completion["usage"]["prompt_tokens"] if "usage" in completion else 0
|
||||||
|
for completion in completions
|
||||||
),
|
),
|
||||||
"total_tokens": sum(
|
"total_tokens": sum(
|
||||||
completion["usage"]["total_tokens"] for completion in completions
|
completion["usage"]["total_tokens"] if "usage" in completion else 0
|
||||||
|
for completion in completions
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue