fix: pass correct type to chat handlers for chat completion logprobs

Andrei Betlen 2024-04-10 03:41:55 -04:00
parent 060bfa64d5
commit bb65b4d764
2 changed files with 18 additions and 9 deletions
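
This commit changes how logprobs reach the chat handlers. Previously, Llama.create_chat_completion (first hunk, class Llama in llama_cpp/llama.py) collapsed the OpenAI-style pair into a single value before calling the handler; now it forwards the boolean logprobs flag and the integer top_logprobs count unchanged, and the handlers in llama_cpp/llama_chat_format.py (remaining hunks) accept both parameters and do the conversion themselves. A minimal usage sketch, assuming the post-commit API and a hypothetical local GGUF model path:

from llama_cpp import Llama

llm = Llama(model_path="./models/example-model.gguf")  # hypothetical model path

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Say hello"}],
    logprobs=True,    # OpenAI-style boolean flag
    top_logprobs=3,   # number of alternatives per token
)

# The chat handler now receives logprobs/top_logprobs as-is, and the returned
# choice carries a "logprobs" field instead of a hard-coded None.
print(response["choices"][0]["logprobs"])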

@@ -1664,7 +1664,8 @@ class Llama:
 top_k=top_k,
 min_p=min_p,
 typical_p=typical_p,
-logprobs=top_logprobs if logprobs else None,
+logprobs=logprobs,
+top_logprobs=top_logprobs,
 stream=stream,
 stop=stop,
 seed=seed,

@@ -77,6 +77,8 @@ class LlamaChatCompletionHandler(Protocol):
 mirostat_eta: float = 0.1,
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
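
The LlamaChatCompletionHandler Protocol above is what every handler must satisfy, so the two new parameters are added to its signature first: logprobs as an Optional[bool] (the OpenAI chat-API flag) and top_logprobs as an Optional[int], replacing the previous arrangement in which llama.py pre-converted the pair to a single integer and passed it under the name logprobs.
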
@@ -338,7 +340,7 @@ def _convert_completion_to_chat_function(
 }
 ],
 },
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "finish_reason": "tool_calls",
 }
 ],
@@ -391,7 +393,7 @@ def _convert_completion_to_chat_function
 {
 "index": 0,
 "finish_reason": None,
-"logprobs": None,
+"logprobs": chunk["choices"][0]["logprobs"],
 "delta": {
 "role": None,
 "content": None,
@@ -426,7 +428,7 @@ def _convert_completion_to_chat_function
 {
 "index": 0,
 "finish_reason": None,
-"logprobs": None,
+"logprobs": chunk["choices"][0]["logprobs"],
 "delta": {
 "role": None,
 "content": None,
@@ -491,7 +493,6 @@ def chat_formatter_to_chat_completion_handler(
 temperature: float = 0.2,
 top_p: float = 0.95,
 top_k: int = 40,
-logprobs: int = 0,
 min_p: float = 0.05,
 typical_p: float = 1.0,
 stream: bool = False,
@@ -512,6 +513,8 @@ def chat_formatter_to_chat_completion_handler(
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
 logit_bias: Optional[Dict[str, float]] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
@@ -581,7 +584,7 @@ def chat_formatter_to_chat_completion_handler(
 top_k=top_k,
 min_p=min_p,
 typical_p=typical_p,
-logprobs=logprobs,
+logprobs=top_logprobs if logprobs else None,
 stream=stream,
 stop=stop,
 seed=seed,
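
This hunk is the counterpart of the change in llama.py: inside the generic handler the OpenAI-style pair is mapped back onto the completion API, which takes a single optional integer. A minimal sketch of that mapping, with completion_logprobs as a hypothetical name for the inline expression used in the diff:

from typing import Optional

def completion_logprobs(logprobs: Optional[bool], top_logprobs: Optional[int]) -> Optional[int]:
    # Chat API: logprobs is a boolean flag, top_logprobs the number of alternatives.
    # Completion API: a single optional int, or None to disable logprobs entirely.
    return top_logprobs if logprobs else None

assert completion_logprobs(True, 3) == 3      # logprobs requested: forward the count
assert completion_logprobs(False, 3) is None  # flag off: disable logprobs
assert completion_logprobs(None, None) is None
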
@@ -1628,7 +1631,7 @@ def functionary_chat_handler(
 }
 ],
 },
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "finish_reason": "tool_calls",
 }
 ],
@@ -2085,7 +2088,7 @@ def functionary_v1_v2_chat_handler(
 choices=[
 {
 "index": 0,
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "message": {
 "role": "assistant",
 "content": None if content == "" else content,
@@ -2311,11 +2314,14 @@ def chatml_function_calling(
 model: Optional[str] = None,
 logits_processor: Optional[llama.LogitsProcessorList] = None,
 grammar: Optional[llama.LlamaGrammar] = None,
+logprobs: Optional[bool] = None,
+top_logprobs: Optional[int] = None,
 **kwargs, # type: ignore
 ) -> Union[
 llama_types.CreateChatCompletionResponse,
 Iterator[llama_types.CreateChatCompletionStreamResponse],
 ]:
+print(logprobs)
 function_calling_template = (
 "{% for message in messages %}"
 "<|im_start|>{{ message.role }}\n"
@@ -2437,6 +2443,7 @@ def chatml_function_calling(
 model=model,
 logits_processor=logits_processor,
 grammar=grammar,
+logprobs=top_logprobs if logprobs else None,
 ),
 stream=stream,
 )
@@ -2549,6 +2556,7 @@ def chatml_function_calling(
 typical_p=typical_p,
 stream=stream,
 stop=["<|im_end|>"],
+logprobs=top_logprobs if logprobs else None,
 max_tokens=None,
 presence_penalty=presence_penalty,
 frequency_penalty=frequency_penalty,
@@ -2660,7 +2668,7 @@ def chatml_function_calling(
 {
 "finish_reason": "tool_calls",
 "index": 0,
-"logprobs": None,
+"logprobs": completion["choices"][0]["logprobs"],
 "message": {
 "role": "assistant",
 "content": None,