Merge pull request #351 from player1537-forks/th/add-logits-bias-parameter
Add support for `logit_bias` and `logit_bias_type` parameters
This commit is contained in:
commit
f568baeef1
2 changed files with 53 additions and 2 deletions
|
@ -1378,6 +1378,7 @@ class Llama:
|
|||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
||||
"""Generate a chat completion from a list of messages.
|
||||
|
||||
|
@ -1419,6 +1420,7 @@ class Llama:
|
|||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
if stream:
|
||||
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
||||
|
|
|
@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
|
|||
)
|
||||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
n: Optional[int] = 1
|
||||
logprobs: Optional[int] = Field(None)
|
||||
best_of: Optional[int] = 1
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
user: Optional[str] = Field(None)
|
||||
|
||||
# llama.cpp specific parameters
|
||||
|
@ -280,6 +281,39 @@ class CreateCompletionRequest(BaseModel):
|
|||
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
|
||||
|
||||
|
||||
def make_logit_bias_processor(
|
||||
llama: llama_cpp.Llama,
|
||||
logit_bias: Dict[str, float],
|
||||
logit_bias_type: Optional[Literal["input_ids", "tokens"]],
|
||||
):
|
||||
if logit_bias_type is None:
|
||||
logit_bias_type = "input_ids"
|
||||
|
||||
to_bias: Dict[int, float] = {}
|
||||
if logit_bias_type == "input_ids":
|
||||
for input_id, score in logit_bias.items():
|
||||
input_id = int(input_id)
|
||||
to_bias[input_id] = score
|
||||
|
||||
elif logit_bias_type == "tokens":
|
||||
for token, score in logit_bias.items():
|
||||
token = token.encode('utf-8')
|
||||
for input_id in llama.tokenize(token, add_bos=False):
|
||||
to_bias[input_id] = score
|
||||
|
||||
def logit_bias_processor(
|
||||
input_ids: List[int],
|
||||
scores: List[float],
|
||||
) -> List[float]:
|
||||
new_scores = [None] * len(scores)
|
||||
for input_id, score in enumerate(scores):
|
||||
new_scores[input_id] = score + to_bias.get(input_id, 0.0)
|
||||
|
||||
return new_scores
|
||||
|
||||
return logit_bias_processor
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
response_model=CreateCompletionResponse,
|
||||
|
@ -297,9 +331,16 @@ async def create_completion(
|
|||
"n",
|
||||
"best_of",
|
||||
"logit_bias",
|
||||
"logit_bias_type",
|
||||
"user",
|
||||
}
|
||||
kwargs = body.dict(exclude=exclude)
|
||||
|
||||
if body.logit_bias is not None:
|
||||
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
|
||||
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||
])
|
||||
|
||||
if body.stream:
|
||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||
|
||||
|
@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
stream: bool = stream_field
|
||||
presence_penalty: Optional[float] = presence_penalty_field
|
||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
n: Optional[int] = 1
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
user: Optional[str] = Field(None)
|
||||
|
||||
# llama.cpp specific parameters
|
||||
|
@ -419,9 +461,16 @@ async def create_chat_completion(
|
|||
exclude = {
|
||||
"n",
|
||||
"logit_bias",
|
||||
"logit_bias_type",
|
||||
"user",
|
||||
}
|
||||
kwargs = body.dict(exclude=exclude)
|
||||
|
||||
if body.logit_bias is not None:
|
||||
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
|
||||
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
|
||||
])
|
||||
|
||||
if body.stream:
|
||||
send_chan, recv_chan = anyio.create_memory_object_stream(10)
|
||||
|
||||
|
|
Loading…
Reference in a new issue