fix: pass correct type to chat handlers for chat completion logprobs
parent 060bfa64d5
commit bb65b4d764

2 changed files with 18 additions and 9 deletions
llama_cpp/llama.py

@@ -1664,7 +1664,8 @@ class Llama:
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            logprobs=top_logprobs if logprobs else None,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
             stream=stream,
             stop=stop,
             seed=seed,
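The effect of this first hunk: `Llama.create_chat_completion` now forwards the OpenAI-style pair untouched, `logprobs` as a bool toggle and `top_logprobs` as an int, leaving any conversion to the chat handler. A minimal usage sketch (model path and chat format are placeholders, not part of this commit):

    from llama_cpp import Llama

    # Model path and chat_format are placeholders for illustration.
    llm = Llama(model_path="./models/model.gguf", chat_format="chatml")

    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Hello!"}],
        logprobs=True,     # bool toggle, as in the OpenAI chat API
        top_logprobs=5,    # int: number of alternatives to report per token
    )
    print(response["choices"][0]["logprobs"])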
llama_cpp/llama_chat_format.py

@@ -77,6 +77,8 @@ class LlamaChatCompletionHandler(Protocol):
         mirostat_eta: float = 0.1,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
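Because the Protocol gained two keyword parameters, any custom chat handler must accept them to keep satisfying `LlamaChatCompletionHandler`. A skeletal sketch of a conforming handler; the name, the prompt rendering, and the body are illustrative only:

    from typing import Optional

    def my_chat_handler(
        *,
        llama,                                # llama.Llama instance
        messages,                             # chat messages to format
        logprobs: Optional[bool] = None,      # new: OpenAI-style bool toggle
        top_logprobs: Optional[int] = None,   # new: alternatives per token
        **kwargs,  # type: ignore
    ):
        # Collapse the chat-level (bool, int) pair into the single
        # Optional[int] that the low-level completion API expects.
        return llama.create_completion(
            prompt="...",  # would be rendered from `messages` by a template
            logprobs=top_logprobs if logprobs else None,
        )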
@@ -338,7 +340,7 @@ def _convert_completion_to_chat_function(
                     }
                 ],
             },
-            "logprobs": None,
+            "logprobs": completion["choices"][0]["logprobs"],
             "finish_reason": "tool_calls",
         }
     ],
@@ -391,7 +393,7 @@ def _convert_completion_to_chat_function(
                 {
                     "index": 0,
                     "finish_reason": None,
-                    "logprobs": None,
+                    "logprobs": chunk["choices"][0]["logprobs"],
                     "delta": {
                         "role": None,
                         "content": None,
@@ -426,7 +428,7 @@ def _convert_completion_to_chat_function(
                 {
                     "index": 0,
                     "finish_reason": None,
-                    "logprobs": None,
+                    "logprobs": chunk["choices"][0]["logprobs"],
                     "delta": {
                         "role": None,
                         "content": None,
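These two streaming hunks make each emitted chunk carry the underlying completion's logprobs instead of a hardcoded None. Reading them from a stream, reusing `llm` from the earlier sketch (whether a given chunk actually has logprobs still depends on the handler path taken):

    stream = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Hi"}],
        logprobs=True,
        top_logprobs=3,
        stream=True,
    )
    for chunk in stream:
        lp = chunk["choices"][0].get("logprobs")
        if lp is not None:
            print(lp)  # completion-style logprobs for this chunk's tokens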
@@ -491,7 +493,6 @@ def chat_formatter_to_chat_completion_handler(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
-        logprobs: int = 0,
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
@@ -512,6 +513,8 @@ def chat_formatter_to_chat_completion_handler(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -581,7 +584,7 @@ def chat_formatter_to_chat_completion_handler(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            logprobs=logprobs,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
             stop=stop,
             seed=seed,
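The conversion that the first hunk removed from `Llama.create_chat_completion` reappears here inside the handler, at the point where the low-level completion API is called. Isolated as a hypothetical helper (not part of the codebase), the mapping is:

    from typing import Optional

    def to_completion_logprobs(logprobs: Optional[bool],
                               top_logprobs: Optional[int]) -> Optional[int]:
        # Chat API: a bool toggle plus an int count.
        # Completion API: a single Optional[int] (alternatives per token).
        return top_logprobs if logprobs else None

    assert to_completion_logprobs(True, 5) == 5
    assert to_completion_logprobs(False, 5) is None
    assert to_completion_logprobs(None, None) is None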
@@ -1628,7 +1631,7 @@ def functionary_chat_handler(
                     }
                 ],
             },
-            "logprobs": None,
+            "logprobs": completion["choices"][0]["logprobs"],
             "finish_reason": "tool_calls",
         }
     ],
@@ -2085,7 +2088,7 @@ def functionary_v1_v2_chat_handler(
         choices=[
             {
                 "index": 0,
-                "logprobs": None,
+                "logprobs": completion["choices"][0]["logprobs"],
                 "message": {
                     "role": "assistant",
                     "content": None if content == "" else content,
@@ -2311,11 +2314,14 @@ def chatml_function_calling(
     model: Optional[str] = None,
     logits_processor: Optional[llama.LogitsProcessorList] = None,
     grammar: Optional[llama.LlamaGrammar] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     **kwargs,  # type: ignore
 ) -> Union[
     llama_types.CreateChatCompletionResponse,
     Iterator[llama_types.CreateChatCompletionStreamResponse],
 ]:
+    print(logprobs)
     function_calling_template = (
         "{% for message in messages %}"
         "<|im_start|>{{ message.role }}\n"
@@ -2437,6 +2443,7 @@ def chatml_function_calling(
                 model=model,
                 logits_processor=logits_processor,
                 grammar=grammar,
+                logprobs=top_logprobs if logprobs else None,
             ),
             stream=stream,
         )
@@ -2549,6 +2556,7 @@ def chatml_function_calling(
         typical_p=typical_p,
         stream=stream,
         stop=["<|im_end|>"],
+        logprobs=top_logprobs if logprobs else None,
         max_tokens=None,
         presence_penalty=presence_penalty,
         frequency_penalty=frequency_penalty,
@@ -2660,7 +2668,7 @@ def chatml_function_calling(
             {
                 "finish_reason": "tool_calls",
                 "index": 0,
-                "logprobs": None,
+                "logprobs": completion["choices"][0]["logprobs"],
                 "message": {
                     "role": "assistant",
                     "content": None,
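Every response-side hunk in this commit swaps a hardcoded "logprobs": None for the logprobs object of the underlying completion, so chat responses now surface it verbatim. Note this is the completion-style object (tokens / token_logprobs / top_logprobs / text_offset), not the nested content list of OpenAI's chat format. Reading it from `response` in the first sketch:

    choice = response["choices"][0]
    if choice["logprobs"] is not None:
        print(choice["logprobs"]["tokens"])          # generated tokens
        print(choice["logprobs"]["token_logprobs"])  # log-probability per token
        print(choice["logprobs"]["top_logprobs"])    # per-token alternatives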