From aa9f1ae011fbc22893750209af500fee3167f21c Mon Sep 17 00:00:00 2001
From: windspirit95 <32880949+windspirit95@users.noreply.github.com>
Date: Mon, 1 Apr 2024 02:30:13 +0900
Subject: [PATCH] feat: Add logprobs support to chat completions (#1311)

* Add logprobs return in ChatCompletionResponse

* Fix duplicate field

* Set default to false

* Simplify check

* Add server example

---------

Co-authored-by: Andrei Betlen
---
 llama_cpp/llama.py             |  1 +
 llama_cpp/llama_chat_format.py |  5 +++++
 llama_cpp/llama_types.py       |  1 +
 llama_cpp/server/app.py        | 12 ++++++++++++
 llama_cpp/server/types.py      | 10 +++++++++-
 5 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index fed84d5..66caaa9 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1653,6 +1653,7 @@ class Llama:
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
             stream=stream,
             stop=stop,
             seed=seed,
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index ccf4fd0..06cf9ce 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -231,6 +231,7 @@ def _convert_text_completion_to_chat(
                     "role": "assistant",
                     "content": completion["choices"][0]["text"],
                 },
+                "logprobs": completion["choices"][0]["logprobs"],
                 "finish_reason": completion["choices"][0]["finish_reason"],
             }
         ],
@@ -254,6 +255,7 @@ def _convert_text_completion_chunks_to_chat(
                         "delta": {
                             "role": "assistant",
                         },
+                        "logprobs": None,
                         "finish_reason": None,
                     }
                 ],
@@ -273,6 +275,7 @@ def _convert_text_completion_chunks_to_chat(
                         if chunk["choices"][0]["finish_reason"] is None
                         else {}
                     ),
+                    "logprobs": chunk["choices"][0]["logprobs"],
                     "finish_reason": chunk["choices"][0]["finish_reason"],
                 }
             ],
@@ -487,6 +490,7 @@ def chat_formatter_to_chat_completion_handler(
         temperature: float = 0.2,
         top_p: float = 0.95,
         top_k: int = 40,
+        logprobs: int = 0,
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
@@ -576,6 +580,7 @@ def chat_formatter_to_chat_completion_handler(
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
+            logprobs=logprobs,
             stream=stream,
             stop=stop,
             seed=seed,
diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index 1b1befe..87e000f 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -84,6 +84,7 @@ class ChatCompletionFunction(TypedDict):
 class ChatCompletionResponseChoice(TypedDict):
     index: int
     message: "ChatCompletionResponseMessage"
+    logprobs: Optional[CompletionLogprobs]
     finish_reason: Optional[str]
 
 
diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index d5e2f7e..815ed3c 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -405,6 +405,18 @@ async def create_chat_completion(
                     }
                 },
             },
+            "logprobs": {
+                "summary": "Logprobs",
+                "value": {
+                    "model": "gpt-3.5-turbo",
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "What is the capital of France?"},
+                    ],
+                    "logprobs": True,
+                    "top_logprobs": 10
+                },
+            },
         }
     ),
     llama_proxy: LlamaProxy = Depends(get_llama_proxy),
diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py
index 378f8d7..ce9c87a 100644
--- a/llama_cpp/server/types.py
+++ b/llama_cpp/server/types.py
@@ -130,7 +130,6 @@ class CreateCompletionRequest(BaseModel):
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
     logit_bias: Optional[Dict[str, float]] = Field(None)
-    logprobs: Optional[int] = Field(None)
     seed: Optional[int] = Field(None)
 
     # ignored or currently unsupported
@@ -209,6 +208,15 @@ class CreateChatCompletionRequest(BaseModel):
         default=None,
         description="The maximum number of tokens to generate. Defaults to inf",
     )
+    logprobs: Optional[bool] = Field(
+        default=False,
+        description="Whether to output the logprobs or not. Default is False"
+    )
+    top_logprobs: Optional[int] = Field(
+        default=None,
+        ge=0,
+        description="The number of logprobs to generate. If None, no logprobs are generated. logprobs must be set to True.",
+    )
     temperature: float = temperature_field
     top_p: float = top_p_field
     min_p: float = min_p_field
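
Usage sketch (not part of the diff): with this patch applied, the new arguments can be exercised directly through Llama.create_chat_completion. This is a minimal illustration, assuming a placeholder model path and that the choice-level "logprobs" field mirrors the text-completion CompletionLogprobs layout wired up in _convert_text_completion_to_chat above.

from llama_cpp import Llama

# Placeholder path; substitute any local GGUF chat model.
llm = Llama(model_path="./models/model.gguf")

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    logprobs=True,   # enable logprobs in the chat response
    top_logprobs=5,  # forwarded to the underlying completion as the logprobs count
)

choice = response["choices"][0]
print(choice["message"]["content"])
print(choice["logprobs"])  # new field added by this patch; None when logprobs is False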
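
Server-side, the same fields map onto the OpenAPI example added in app.py. A minimal request sketch, assuming the bundled server was started with `python -m llama_cpp.server --model <path-to-gguf>` and is listening on its default localhost:8000; the model name is only an alias echoed back by the server.

import requests

# Assumes a local llama_cpp.server instance on the default host/port.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        "logprobs": True,
        "top_logprobs": 10,
    },
    timeout=120,
)
data = resp.json()
print(data["choices"][0]["message"]["content"])
print(data["choices"][0]["logprobs"])  # populated because logprobs was requested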