feat: Add logprobs support to chat completions (#1311)

* Add logprobs return in ChatCompletionResponse

* Fix duplicate field

* Set default to false

* Simplify check

* Add server example

---------

Co-authored-by: Andrei Betlen <abetlen@gmail.com>
This commit is contained in:
windspirit95 2024-04-01 02:30:13 +09:00 committed by GitHub
parent 1e60dba082
commit aa9f1ae011
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 28 additions and 1 deletions

View file

@ -1653,6 +1653,7 @@ class Llama:
top_k=top_k, top_k=top_k,
min_p=min_p, min_p=min_p,
typical_p=typical_p, typical_p=typical_p,
logprobs=top_logprobs if logprobs else None,
stream=stream, stream=stream,
stop=stop, stop=stop,
seed=seed, seed=seed,

View file

@ -231,6 +231,7 @@ def _convert_text_completion_to_chat(
"role": "assistant", "role": "assistant",
"content": completion["choices"][0]["text"], "content": completion["choices"][0]["text"],
}, },
"logprobs": completion["choices"][0]["logprobs"],
"finish_reason": completion["choices"][0]["finish_reason"], "finish_reason": completion["choices"][0]["finish_reason"],
} }
], ],
@ -254,6 +255,7 @@ def _convert_text_completion_chunks_to_chat(
"delta": { "delta": {
"role": "assistant", "role": "assistant",
}, },
"logprobs": None,
"finish_reason": None, "finish_reason": None,
} }
], ],
@ -273,6 +275,7 @@ def _convert_text_completion_chunks_to_chat(
if chunk["choices"][0]["finish_reason"] is None if chunk["choices"][0]["finish_reason"] is None
else {} else {}
), ),
"logprobs": chunk["choices"][0]["logprobs"],
"finish_reason": chunk["choices"][0]["finish_reason"], "finish_reason": chunk["choices"][0]["finish_reason"],
} }
], ],
@ -487,6 +490,7 @@ def chat_formatter_to_chat_completion_handler(
temperature: float = 0.2, temperature: float = 0.2,
top_p: float = 0.95, top_p: float = 0.95,
top_k: int = 40, top_k: int = 40,
logprobs: int = 0,
min_p: float = 0.05, min_p: float = 0.05,
typical_p: float = 1.0, typical_p: float = 1.0,
stream: bool = False, stream: bool = False,
@ -576,6 +580,7 @@ def chat_formatter_to_chat_completion_handler(
top_k=top_k, top_k=top_k,
min_p=min_p, min_p=min_p,
typical_p=typical_p, typical_p=typical_p,
logprobs=logprobs,
stream=stream, stream=stream,
stop=stop, stop=stop,
seed=seed, seed=seed,

View file

@ -84,6 +84,7 @@ class ChatCompletionFunction(TypedDict):
class ChatCompletionResponseChoice(TypedDict): class ChatCompletionResponseChoice(TypedDict):
index: int index: int
message: "ChatCompletionResponseMessage" message: "ChatCompletionResponseMessage"
logprobs: Optional[CompletionLogprobs]
finish_reason: Optional[str] finish_reason: Optional[str]

View file

@ -405,6 +405,18 @@ async def create_chat_completion(
} }
}, },
}, },
"logprobs": {
"summary": "Logprobs",
"value": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
],
"logprobs": True,
"top_logprobs": 10
},
},
} }
), ),
llama_proxy: LlamaProxy = Depends(get_llama_proxy), llama_proxy: LlamaProxy = Depends(get_llama_proxy),

View file

@ -130,7 +130,6 @@ class CreateCompletionRequest(BaseModel):
presence_penalty: Optional[float] = presence_penalty_field presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None) logit_bias: Optional[Dict[str, float]] = Field(None)
logprobs: Optional[int] = Field(None)
seed: Optional[int] = Field(None) seed: Optional[int] = Field(None)
# ignored or currently unsupported # ignored or currently unsupported
@ -209,6 +208,15 @@ class CreateChatCompletionRequest(BaseModel):
default=None, default=None,
description="The maximum number of tokens to generate. Defaults to inf", description="The maximum number of tokens to generate. Defaults to inf",
) )
logprobs: Optional[bool] = Field(
default=False,
description="Whether to output the logprobs or not. Default is True"
)
top_logprobs: Optional[int] = Field(
default=None,
ge=0,
description="The number of logprobs to generate. If None, no logprobs are generated. logprobs need to set to True.",
)
temperature: float = temperature_field temperature: float = temperature_field
top_p: float = top_p_field top_p: float = top_p_field
min_p: float = min_p_field min_p: float = min_p_field