feat: Add logprobs support to chat completions (#1311)
* Add logprobs return in ChatCompletionResponse * Fix duplicate field * Set default to false * Simplify check * Add server example --------- Co-authored-by: Andrei Betlen <abetlen@gmail.com>
This commit is contained in:
parent
1e60dba082
commit
aa9f1ae011
5 changed files with 28 additions and 1 deletions
|
@ -1653,6 +1653,7 @@ class Llama:
|
|||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
logprobs=top_logprobs if logprobs else None,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
|
|
|
@ -231,6 +231,7 @@ def _convert_text_completion_to_chat(
|
|||
"role": "assistant",
|
||||
"content": completion["choices"][0]["text"],
|
||||
},
|
||||
"logprobs": completion["choices"][0]["logprobs"],
|
||||
"finish_reason": completion["choices"][0]["finish_reason"],
|
||||
}
|
||||
],
|
||||
|
@ -254,6 +255,7 @@ def _convert_text_completion_chunks_to_chat(
|
|||
"delta": {
|
||||
"role": "assistant",
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
|
@ -273,6 +275,7 @@ def _convert_text_completion_chunks_to_chat(
|
|||
if chunk["choices"][0]["finish_reason"] is None
|
||||
else {}
|
||||
),
|
||||
"logprobs": chunk["choices"][0]["logprobs"],
|
||||
"finish_reason": chunk["choices"][0]["finish_reason"],
|
||||
}
|
||||
],
|
||||
|
@ -487,6 +490,7 @@ def chat_formatter_to_chat_completion_handler(
|
|||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
logprobs: int = 0,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
|
@ -576,6 +580,7 @@ def chat_formatter_to_chat_completion_handler(
|
|||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
logprobs=logprobs,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
|
|
|
@ -84,6 +84,7 @@ class ChatCompletionFunction(TypedDict):
|
|||
class ChatCompletionResponseChoice(TypedDict):
|
||||
index: int
|
||||
message: "ChatCompletionResponseMessage"
|
||||
logprobs: Optional[CompletionLogprobs]
|
||||
finish_reason: Optional[str]
|
||||
|
||||
|
||||
|
|
|
@ -405,6 +405,18 @@ async def create_chat_completion(
|
|||
}
|
||||
},
|
||||
},
|
||||
"logprobs": {
|
||||
"summary": "Logprobs",
|
||||
"value": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
|
||||
|
|
|
@ -130,7 +130,6 @@ class CreateCompletionRequest(BaseModel):
|
|||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
logprobs: Optional[int] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
# ignored or currently unsupported
|
||||
|
@ -209,6 +208,15 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
default=None,
|
||||
description="The maximum number of tokens to generate. Defaults to inf",
|
||||
)
|
||||
logprobs: Optional[bool] = Field(
|
||||
default=False,
|
||||
description="Whether to output the logprobs or not. Default is True"
|
||||
)
|
||||
top_logprobs: Optional[int] = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
description="The number of logprobs to generate. If None, no logprobs are generated. logprobs need to set to True.",
|
||||
)
|
||||
temperature: float = temperature_field
|
||||
top_p: float = top_p_field
|
||||
min_p: float = min_p_field
|
||||
|
|
Loading…
Add table
Reference in a new issue