Merge pull request #351 from player1537-forks/th/add-logits-bias-parameter

Add support for `logit_bias` and `logit_bias_type` parameters
Andrei 2023-06-14 21:49:56 -04:00 committed by GitHub
commit f568baeef1
2 changed files with 53 additions and 2 deletions
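As a quick orientation before the diff: the server now accepts `logit_bias` (a map from token key to additive logit score) and `logit_bias_type` (either `"input_ids"` or `"tokens"`, defaulting to `"input_ids"`). A minimal usage sketch against a locally running server; the host, port, and the token id used are assumptions, not part of this commit:

```python
import requests  # any HTTP client works; requests is an assumption

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed default host/port
    json={
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        # With "input_ids", keys are token ids as strings; each score is
        # added directly to that token's logit before sampling.
        "logit_bias": {"15043": 5.0},
        "logit_bias_type": "input_ids",
    },
)
print(resp.json()["choices"][0]["text"])
```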

llama_cpp/llama.py

@@ -1378,6 +1378,7 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
         """Generate a chat completion from a list of messages.
@@ -1419,6 +1420,7 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
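At the Python API level, this wires a user-supplied processor list through `create_chat_completion` into `create_completion`. A hedged sketch of calling it directly; the model path and the boosted token id are made-up values:

```python
import llama_cpp

llama = llama_cpp.Llama(model_path="./models/7B/ggml-model.bin")  # assumed path

def boost_token(input_ids, scores):
    # A logits processor: given the tokens generated so far and the current
    # logits, return adjusted logits. Here we flatly favor token id 5.
    scores[5] += 10.0
    return scores

result = llama.create_chat_completion(
    messages=[{"role": "user", "content": "Say something."}],
    logits_processor=llama_cpp.LogitsProcessorList([boost_token]),
)
print(result["choices"][0]["message"]["content"])
```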

llama_cpp/server/app.py

@@ -255,13 +255,14 @@ class CreateCompletionRequest(BaseModel):
     )
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
     logprobs: Optional[int] = Field(None)
     best_of: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)
 
     # llama.cpp specific parameters
@@ -280,6 +281,39 @@ class CreateCompletionRequest(BaseModel):
 CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
 
 
+def make_logit_bias_processor(
+    llama: llama_cpp.Llama,
+    logit_bias: Dict[str, float],
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
+):
+    if logit_bias_type is None:
+        logit_bias_type = "input_ids"
+
+    to_bias: Dict[int, float] = {}
+    if logit_bias_type == "input_ids":
+        for input_id, score in logit_bias.items():
+            input_id = int(input_id)
+            to_bias[input_id] = score
+
+    elif logit_bias_type == "tokens":
+        for token, score in logit_bias.items():
+            token = token.encode('utf-8')
+            for input_id in llama.tokenize(token, add_bos=False):
+                to_bias[input_id] = score
+
+    def logit_bias_processor(
+        input_ids: List[int],
+        scores: List[float],
+    ) -> List[float]:
+        new_scores = [None] * len(scores)
+        for input_id, score in enumerate(scores):
+            new_scores[input_id] = score + to_bias.get(input_id, 0.0)
+
+        return new_scores
+
+    return logit_bias_processor
+
+
 @router.post(
     "/v1/completions",
     response_model=CreateCompletionResponse,
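The helper above collapses the request's `logit_bias` map into a `Dict[int, float]` of token id to score, then returns a closure that shifts each biased token's logit by its score at every sampling step. The same arithmetic in isolation, with a made-up four-token vocabulary:

```python
from typing import Dict, List

def apply_logit_bias(scores: List[float], to_bias: Dict[int, float]) -> List[float]:
    # Equivalent core of logit_bias_processor: biased token ids get their
    # scores shifted; everything else passes through unchanged.
    return [score + to_bias.get(token_id, 0.0) for token_id, score in enumerate(scores)]

fake_logits = [0.1, 0.2, 0.3, 0.4]              # one score per vocabulary token
print(apply_logit_bias(fake_logits, {2: 5.0}))  # [0.1, 0.2, 5.3, 0.4]
```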
@@ -297,9 +331,16 @@ async def create_completion(
         "n",
         "best_of",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
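Note the two key interpretations: with `"input_ids"` the handler biases exact token ids, while `"tokens"` runs each key through the model's tokenizer, so a single string key can bias several ids. A short illustration; the resulting ids depend entirely on the loaded model's vocabulary, so the value shown is an assumption:

```python
# llama is a llama_cpp.Llama instance; the printed ids are illustrative only.
ids = llama.tokenize("Hello".encode('utf-8'), add_bos=False)
print(ids)  # e.g. [15043]; every id produced this way receives the same bias
```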
@@ -378,11 +419,12 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = stream_field
     presence_penalty: Optional[float] = presence_penalty_field
     frequency_penalty: Optional[float] = frequency_penalty_field
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
 
     # ignored or currently unsupported
     model: Optional[str] = model_field
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(None)
     user: Optional[str] = Field(None)
 
     # llama.cpp specific parameters
@@ -419,9 +461,16 @@ async def create_chat_completion(
     exclude = {
         "n",
         "logit_bias",
+        "logit_bias_type",
         "user",
     }
     kwargs = body.dict(exclude=exclude)
+
+    if body.logit_bias is not None:
+        kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
+            make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
+        ])
 
     if body.stream:
         send_chan, recv_chan = anyio.create_memory_object_stream(10)
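The chat endpoint mirrors the completions endpoint exactly: both strip the two fields from the request body and translate them into a `logits_processor` kwarg. For completeness, the same request shape against `/v1/chat/completions`; host, port, and the biased string are assumptions:

```python
import requests  # assumed client, as in the earlier sketch

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed default host/port
    json={
        "messages": [{"role": "user", "content": "Pick a number."}],
        # "tokens" mode: the server tokenizes each key via
        # make_logit_bias_processor before applying the bias.
        "logit_bias": {" seven": 8.0},
        "logit_bias_type": "tokens",
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```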