From bb65b4d76411112c6fb0bf759efd746f99ef3c6b Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Wed, 10 Apr 2024 03:41:55 -0400
Subject: [PATCH] fix: pass correct type to chat handlers for chat completion
 logprobs

---
 llama_cpp/llama.py             |  3 ++-
 llama_cpp/llama_chat_format.py | 24 ++++++++++++++++--------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index e07d57a..466dc22 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1664,7 +1664,8 @@ class Llama:
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            logprobs=top_logprobs if logprobs else None,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
             stream=stream,
             stop=stop,
             seed=seed,
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 705202e..519d2f5 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -77,6 +77,8 @@ class LlamaChatCompletionHandler(Protocol):
         mirostat_eta: float = 0.1,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -338,7 +340,7 @@ def _convert_completion_to_chat_function(
                         }
                     ],
                 },
-                "logprobs": None,
+                "logprobs": completion["choices"][0]["logprobs"],
                 "finish_reason": "tool_calls",
             }
         ],
@@ -391,7 +393,7 @@ def _convert_completion_to_chat_function(
                     {
                         "index": 0,
                         "finish_reason": None,
-                        "logprobs": None,
+                        "logprobs": chunk["choices"][0]["logprobs"],
                         "delta": {
                             "role": None,
                             "content": None,
@@ -426,7 +428,7 @@ def _convert_completion_to_chat_function(
                     {
                         "index": 0,
                         "finish_reason": None,
-                        "logprobs": None,
+                        "logprobs": chunk["choices"][0]["logprobs"],
                         "delta": {
                             "role": None,
                             "content": None,
@@ -491,7 +493,6 @@ def chat_formatter_to_chat_completion_handler(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
-        logprobs: int = 0,
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
@@ -512,6 +513,8 @@ def chat_formatter_to_chat_completion_handler(
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -581,7 +584,7 @@ def chat_formatter_to_chat_completion_handler(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            logprobs=logprobs,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
             stop=stop,
             seed=seed,
@@ -1628,7 +1631,7 @@ def functionary_chat_handler(
                         }
                     ],
                 },
-                "logprobs": None,
+                "logprobs": completion["choices"][0]["logprobs"],
                 "finish_reason": "tool_calls",
             }
         ],
@@ -2085,7 +2088,7 @@ def functionary_v1_v2_chat_handler(
         choices=[
             {
                 "index": 0,
-                "logprobs": None,
+                "logprobs": completion["choices"][0]["logprobs"],
                 "message": {
                     "role": "assistant",
                     "content": None if content == "" else content,
@@ -2311,11 +2314,14 @@ def chatml_function_calling(
     model: Optional[str] = None,
     logits_processor: Optional[llama.LogitsProcessorList] = None,
     grammar: Optional[llama.LlamaGrammar] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     **kwargs,  # type: ignore
 ) -> Union[
     llama_types.CreateChatCompletionResponse,
     Iterator[llama_types.CreateChatCompletionStreamResponse],
 ]:
+    print(logprobs)
     function_calling_template = (
         "{% for message in messages %}"
         "<|im_start|>{{ message.role }}\n"
@@ -2437,6 +2443,7 @@ def chatml_function_calling(
                 model=model,
                 logits_processor=logits_processor,
                 grammar=grammar,
+                logprobs=top_logprobs if logprobs else None,
             ),
             stream=stream,
         )
@@ -2549,6 +2556,7 @@ def chatml_function_calling(
             typical_p=typical_p,
             stream=stream,
             stop=["<|im_end|>"],
+            logprobs=top_logprobs if logprobs else None,
             max_tokens=None,
             presence_penalty=presence_penalty,
             frequency_penalty=frequency_penalty,
@@ -2660,7 +2668,7 @@ def chatml_function_calling(
                 {
                     "finish_reason": "tool_calls",
                     "index": 0,
-                    "logprobs": None,
+                    "logprobs": completion["choices"][0]["logprobs"],
                     "message": {
                         "role": "assistant",
                         "content": None,
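
For reference, a minimal usage sketch of the two parameters this patch forwards to the chat handlers. The model path and prompt are placeholders, and the snippet assumes a llama-cpp-python build that already includes this change.

    from llama_cpp import Llama

    # Load any chat-capable GGUF model; the path here is a placeholder.
    llm = Llama(model_path="./models/example-model.gguf", n_ctx=2048)

    # logprobs (bool) requests log probabilities on the chat completion and
    # top_logprobs (int) sets how many alternatives to return per position.
    # With this patch both values reach the chat handler, which converts them
    # into the integer logprobs expected by the underlying completion call.
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Say hello in one word."}],
        logprobs=True,
        top_logprobs=3,
        max_tokens=16,
    )

    # After the fix, the choice's "logprobs" field carries the completion's
    # logprobs data instead of always being None.
    print(response["choices"][0]["logprobs"])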