fix: pass correct type to chat handlers for chat completion logprobs
commit bb65b4d764
parent 060bfa64d5
2 changed files with 18 additions and 9 deletions
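After this change, a chat completion request can ask for log probabilities the way the OpenAI chat API does, with a boolean logprobs flag plus an integer top_logprobs count, and the handlers translate that pair into the integer-valued logprobs parameter the underlying completion API expects. A minimal usage sketch, assuming a local GGUF model; the model path is a placeholder and the exact shape of the returned logprobs object should be inspected rather than assumed:

from llama_cpp import Llama

# Placeholder model path -- substitute a real GGUF model file.
llm = Llama(model_path="./models/model.gguf")

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello"}],
    logprobs=True,      # chat-style boolean flag
    top_logprobs=3,     # number of alternatives per token
)

# The handlers now forward per-token log probabilities into the chat response.
print(response["choices"][0]["logprobs"])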
@@ -1664,7 +1664,8 @@ class Llama:
 top_k=top_k,
 min_p=min_p,
 typical_p=typical_p,
-logprobs=top_logprobs if logprobs else None,
+logprobs=logprobs,
+top_logprobs=top_logprobs,
 stream=stream,
 stop=stop,
 seed=seed,
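The call site above previously collapsed the two chat-level arguments into a single value before handing them to the chat handler, so a handler expecting a boolean logprobs received an integer. It now forwards both values unchanged and lets the handler do the reduction. A toy sketch of the two conventions involved; the names are illustrative, not the library's:

from typing import Optional

def chat_handler(logprobs: Optional[bool] = None,
                 top_logprobs: Optional[int] = None) -> Optional[int]:
    """Chat-style arguments: a flag plus a count (OpenAI chat convention)."""
    # Inside the handler the pair is reduced to the completion-style value:
    # an integer number of top logprobs, or None when disabled.
    return top_logprobs if logprobs else None

# Before the fix the caller did this reduction itself and passed an int
# where the handler expected a bool; now it passes both values as-is.
assert chat_handler(logprobs=True, top_logprobs=3) == 3
assert chat_handler(logprobs=False, top_logprobs=3) is None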
@@ -77,6 +77,8 @@ class LlamaChatCompletionHandler(Protocol):
 mirostat_eta: float = 0.1,
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
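Because the handler signature is a typing.Protocol, every registered chat handler has to accept the two new keyword arguments (or absorb them through **kwargs) to keep satisfying the protocol. A stripped-down sketch of the structural-typing idea, not the library's full protocol:

from typing import Any, Optional, Protocol

class ChatHandler(Protocol):
    def __call__(
        self,
        *,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs: Any,
    ) -> dict: ...

def my_handler(*, logprobs: Optional[bool] = None,
               top_logprobs: Optional[int] = None, **kwargs: Any) -> dict:
    # A conforming handler simply has to accept the new keywords.
    return {"logprobs_requested": bool(logprobs), "top_logprobs": top_logprobs}

handler: ChatHandler = my_handler  # structural check, no registration needed
print(handler(logprobs=True, top_logprobs=2))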
@@ -338,7 +340,7 @@ def _convert_completion_to_chat_function(
 }
 ],
 },
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "finish_reason": "tool_calls",
 }
 ],
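The converters that wrap a plain completion in a chat-completion envelope used to hard-code "logprobs": None; they now copy whatever the underlying completion produced. A reduced sketch of that propagation, with the dict shapes trimmed down to the relevant field:

def completion_to_chat_choice(completion: dict) -> dict:
    """Copy the completion's logprobs into the chat-style choice, assuming
    the single-choice layout used throughout these converters."""
    src = completion["choices"][0]
    return {
        "index": 0,
        "message": {"role": "assistant", "content": src.get("text")},
        "logprobs": src["logprobs"],     # previously hard-coded to None
        "finish_reason": "tool_calls",
    }

fake_completion = {"choices": [{"text": "hi", "logprobs": {"token_logprobs": [-0.1]}}]}
print(completion_to_chat_choice(fake_completion)["logprobs"])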
@@ -391,7 +393,7 @@ def _convert_completion_to_chat_function(
 {
 "index": 0,
 "finish_reason": None,
-"logprobs": None,
+"logprobs": chunk["choices"][0]["logprobs"],
 "delta": {
 "role": None,
 "content": None,
@@ -426,7 +428,7 @@ def _convert_completion_to_chat_function(
 {
 "index": 0,
 "finish_reason": None,
-"logprobs": None,
+"logprobs": chunk["choices"][0]["logprobs"],
 "delta": {
 "role": None,
 "content": None,
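The same propagation is applied to streamed output: each chat delta now carries the logprobs of the corresponding completion chunk instead of None. A small generator sketch of the streaming side, with simplified chunk dicts:

from typing import Dict, Iterator

def to_chat_chunks(completion_chunks: Iterator[Dict]) -> Iterator[Dict]:
    """Wrap streamed completion chunks as chat deltas, forwarding logprobs."""
    for chunk in completion_chunks:
        src = chunk["choices"][0]
        yield {
            "choices": [{
                "index": 0,
                "finish_reason": None,
                "logprobs": src["logprobs"],   # was always None before
                "delta": {"role": None, "content": src.get("text")},
            }]
        }

chunks = iter([{"choices": [{"text": "h", "logprobs": {"token_logprobs": [-0.2]}}]}])
for c in to_chat_chunks(chunks):
    print(c["choices"][0]["logprobs"])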
@@ -491,7 +493,6 @@ def chat_formatter_to_chat_completion_handler(
 temperature: float = 0.2,
 top_p: float = 0.95,
 top_k: int = 40,
-logprobs: int = 0,
 min_p: float = 0.05,
 typical_p: float = 1.0,
 stream: bool = False,
@@ -512,6 +513,8 @@ def chat_formatter_to_chat_completion_handler(
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
 logit_bias: Optional[Dict[str, float]] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
@@ -581,7 +584,7 @@ def chat_formatter_to_chat_completion_handler(
 top_k=top_k,
 min_p=min_p,
 typical_p=typical_p,
-logprobs=logprobs,
+logprobs=top_logprobs if logprobs else None,
 stream=stream,
 stop=stop,
 seed=seed,
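In the generic formatter-based handler the old integer-typed logprobs parameter is dropped, the chat-style pair takes its place, and the reduction to the completion API's integer-or-None convention happens exactly once, at the create_completion call. A stand-in sketch of where that reduction now lives; the stub names are hypothetical, only the expression mirrors the diff:

from typing import Optional

def create_completion_stub(prompt: str, logprobs: Optional[int] = None) -> dict:
    """Stand-in for the completion API: logprobs is an int count or None."""
    return {"prompt": prompt, "logprobs_requested": logprobs}

def chat_handler_stub(prompt: str,
                      logprobs: Optional[bool] = None,
                      top_logprobs: Optional[int] = None) -> dict:
    # The chat-style pair is reduced once, right at the completion call.
    return create_completion_stub(prompt, logprobs=top_logprobs if logprobs else None)

print(chat_handler_stub("hi", logprobs=True, top_logprobs=4))
print(chat_handler_stub("hi", logprobs=False, top_logprobs=4))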
@@ -1628,7 +1631,7 @@ def functionary_chat_handler(
 }
 ],
 },
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "finish_reason": "tool_calls",
 }
 ],
@@ -2085,7 +2088,7 @@ def functionary_v1_v2_chat_handler(
 choices=[
 {
 "index": 0,
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "message": {
 "role": "assistant",
 "content": None if content == "" else content,
@@ -2311,11 +2314,14 @@ def chatml_function_calling(
 model: Optional[str] = None,
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
 Iterator[llama_types.CreateChatCompletionStreamResponse],
 ]:
+print(logprobs)
 function_calling_template = (
 "{% for message in messages %}"
 "<|im_start|>{{ message.role }}\n"
@@ -2437,6 +2443,7 @@ def chatml_function_calling(
 model=model,
 logits_processor=logits_processor,
 grammar=grammar,
+logprobs=top_logprobs if logprobs else None,
 ),
 stream=stream,
 )
@@ -2549,6 +2556,7 @@ def chatml_function_calling(
 typical_p=typical_p,
 stream=stream,
 stop=["<|im_end|>"],
+logprobs=top_logprobs if logprobs else None,
 max_tokens=None,
 presence_penalty=presence_penalty,
 frequency_penalty=frequency_penalty,
@@ -2660,7 +2668,7 @@ def chatml_function_calling(
 {
 "finish_reason": "tool_calls",
 "index": 0,
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "message": {
 "role": "assistant",
 "content": None,
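With the hard-coded None values gone, callers of the tool-calling handlers can inspect per-token log probabilities in the returned chat completion. A short consumption sketch; since the field is copied verbatim from the underlying completion, its exact keys should be checked rather than assumed:

def summarize_logprobs(chat_response: dict) -> None:
    """Print whatever logprobs information the handler propagated."""
    for choice in chat_response.get("choices", []):
        lp = choice.get("logprobs")
        if lp is None:
            print(f"choice {choice.get('index')}: no logprobs returned")
        else:
            # The structure mirrors the completion API's logprobs object.
            print(f"choice {choice.get('index')}: {lp}")

summarize_logprobs({"choices": [{"index": 0, "logprobs": {"token_logprobs": [-0.3, -1.2]}}]})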