Merge pull request #351 from player1537-forks/th/add-logits-bias-parameter
Add support for `logit_bias` and `logit_bias_type` parameters
This commit is contained in:
commit
f568baeef1
2 changed files with 53 additions and 2 deletions
|
@ -1378,6 +1378,7 @@ class Llama:
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
||||||
"""Generate a chat completion from a list of messages.
|
"""Generate a chat completion from a list of messages.
|
||||||
|
|
||||||
|
@ -1419,6 +1420,7 @@ class Llama:
|
||||||
mirostat_tau=mirostat_tau,
|
mirostat_tau=mirostat_tau,
|
||||||
mirostat_eta=mirostat_eta,
|
mirostat_eta=mirostat_eta,
|
||||||
model=model,
|
model=model,
|
||||||
|
logits_processor=logits_processor,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
||||||
|
|
|
@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
|
||||||
)
|
)
|
||||||
presence_penalty: Optional[float] = presence_penalty_field
|
presence_penalty: Optional[float] = presence_penalty_field
|
||||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||||
|
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||||
|
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||||
|
|
||||||
# ignored or currently unsupported
|
# ignored or currently unsupported
|
||||||
model: Optional[str] = model_field
|
model: Optional[str] = model_field
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
logprobs: Optional[int] = Field(None)
|
logprobs: Optional[int] = Field(None)
|
||||||
best_of: Optional[int] = 1
|
best_of: Optional[int] = 1
|
||||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
|
||||||
user: Optional[str] = Field(None)
|
user: Optional[str] = Field(None)
|
||||||
|
|
||||||
# llama.cpp specific parameters
|
# llama.cpp specific parameters
|
||||||
|
@ -280,6 +281,39 @@ class CreateCompletionRequest(BaseModel):
|
||||||
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||||
|
|
||||||
|
|
||||||
|
def make_logit_bias_processor(
|
||||||
|
llama: llama_cpp.Llama,
|
||||||
|
logit_bias: Dict[str, float],
|
||||||
|
logit_bias_type: Optional[Literal["input_ids", "tokens"]],
|
||||||
|
):
|
||||||
|
if logit_bias_type is None:
|
||||||
|
logit_bias_type = "input_ids"
|
||||||
|
|
||||||
|
to_bias: Dict[int, float] = {}
|
||||||
|
if logit_bias_type == "input_ids":
|
||||||
|
for input_id, score in logit_bias.items():
|
||||||
|
input_id = int(input_id)
|
||||||
|
to_bias[input_id] = score
|
||||||
|
|
||||||
|
elif logit_bias_type == "tokens":
|
||||||
|
for token, score in logit_bias.items():
|
||||||
|
token = token.encode('utf-8')
|
||||||
|
for input_id in llama.tokenize(token, add_bos=False):
|
||||||
|
to_bias[input_id] = score
|
||||||
|
|
||||||
|
def logit_bias_processor(
|
||||||
|
input_ids: List[int],
|
||||||
|
scores: List[float],
|
||||||
|
) -> List[float]:
|
||||||
|
new_scores = [None] * len(scores)
|
||||||
|
for input_id, score in enumerate(scores):
|
||||||
|
new_scores[input_id] = score + to_bias.get(input_id, 0.0)
|
||||||
|
|
||||||
|
return new_scores
|
||||||
|
|
||||||
|
return logit_bias_processor
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
response_model=CreateCompletionResponse,
|
response_model=CreateCompletionResponse,
|
||||||
|
@ -297,9 +331,16 @@ async def create_completion(
|
||||||
"n",
|
"n",
|
||||||
"best_of",
|
"best_of",
|
||||||
"logit_bias",
|
"logit_bias",
|
||||||
|
"logit_bias_type",
|
||||||
"user",
|
"user",
|
||||||
}
|
}
|
||||||
kwargs = body.dict(exclude=exclude)
|
kwargs = body.dict(exclude=exclude)
|
||||||
|
|
||||||
|
if body.logit_bias is not None:
|
||||||
|
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
|
||||||
|
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||||
|
])
|
||||||
|
|
||||||
if body.stream:
|
if body.stream:
|
||||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||||
|
|
||||||
|
@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
|
||||||
stream: bool = stream_field
|
stream: bool = stream_field
|
||||||
presence_penalty: Optional[float] = presence_penalty_field
|
presence_penalty: Optional[float] = presence_penalty_field
|
||||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||||
|
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||||
|
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||||
|
|
||||||
# ignored or currently unsupported
|
# ignored or currently unsupported
|
||||||
model: Optional[str] = model_field
|
model: Optional[str] = model_field
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
|
||||||
user: Optional[str] = Field(None)
|
user: Optional[str] = Field(None)
|
||||||
|
|
||||||
# llama.cpp specific parameters
|
# llama.cpp specific parameters
|
||||||
|
@ -419,9 +461,16 @@ async def create_chat_completion(
|
||||||
exclude = {
|
exclude = {
|
||||||
"n",
|
"n",
|
||||||
"logit_bias",
|
"logit_bias",
|
||||||
|
"logit_bias_type",
|
||||||
"user",
|
"user",
|
||||||
}
|
}
|
||||||
kwargs = body.dict(exclude=exclude)
|
kwargs = body.dict(exclude=exclude)
|
||||||
|
|
||||||
|
if body.logit_bias is not None:
|
||||||
|
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
|
||||||
|
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||||
|
])
|
||||||
|
|
||||||
if body.stream:
|
if body.stream:
|
||||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue