Add support for logit_bias and logit_bias_type parameters

This commit is contained in:
Tanner Hobson 2023-06-09 13:13:08 -04:00
parent 0da655b3be
commit eb7645b3ba
2 changed files with 53 additions and 2 deletions

View file

@ -1380,6 +1380,7 @@ class Llama:
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
model: Optional[str] = None, model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
"""Generate a chat completion from a list of messages. """Generate a chat completion from a list of messages.
@ -1421,6 +1422,7 @@ class Llama:
mirostat_tau=mirostat_tau, mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
model=model, model=model,
logits_processor=logits_processor,
) )
if stream: if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore

View file

@ -249,13 +249,14 @@ class CreateCompletionRequest(BaseModel):
) )
presence_penalty: Optional[float] = presence_penalty_field presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
# ignored or currently unsupported # ignored or currently unsupported
model: Optional[str] = model_field model: Optional[str] = model_field
n: Optional[int] = 1 n: Optional[int] = 1
logprobs: Optional[int] = Field(None) logprobs: Optional[int] = Field(None)
best_of: Optional[int] = 1 best_of: Optional[int] = 1
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None) user: Optional[str] = Field(None)
# llama.cpp specific parameters # llama.cpp specific parameters
@ -274,6 +275,39 @@ class CreateCompletionRequest(BaseModel):
CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion) CreateCompletionResponse = create_model_from_typeddict(llama_cpp.Completion)
def make_logit_bias_processor(
llama: llama_cpp.Llama,
logit_bias: Dict[str, float],
logit_bias_type: Optional[Literal["input_ids", "tokens"]],
):
if logit_bias_type is None:
logit_bias_type = "input_ids"
to_bias: Dict[int, float] = {}
if logit_bias_type == "input_ids":
for input_id, score in logit_bias.items():
input_id = int(input_id)
to_bias[input_id] = score
elif logit_bias_type == "tokens":
for token, score in logit_bias.items():
token = token.encode('utf-8')
for input_id in llama.tokenize(token, add_bos=False):
to_bias[input_id] = score
def logit_bias_processor(
input_ids: List[int],
scores: List[float],
) -> List[float]:
new_scores = [None] * len(scores)
for input_id, score in enumerate(scores):
new_scores[input_id] = score + to_bias.get(input_id, 0.0)
return new_scores
return logit_bias_processor
@router.post( @router.post(
"/v1/completions", "/v1/completions",
response_model=CreateCompletionResponse, response_model=CreateCompletionResponse,
@ -291,9 +325,16 @@ async def create_completion(
"n", "n",
"best_of", "best_of",
"logit_bias", "logit_bias",
"logit_bias_type",
"user", "user",
} }
kwargs = body.dict(exclude=exclude) kwargs = body.dict(exclude=exclude)
if body.logit_bias is not None:
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
])
if body.stream: if body.stream:
send_chan, recv_chan = anyio.create_memory_object_stream(10) send_chan, recv_chan = anyio.create_memory_object_stream(10)
@ -372,11 +413,12 @@ class CreateChatCompletionRequest(BaseModel):
stream: bool = stream_field stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
# ignored or currently unsupported # ignored or currently unsupported
model: Optional[str] = model_field model: Optional[str] = model_field
n: Optional[int] = 1 n: Optional[int] = 1
logit_bias: Optional[Dict[str, float]] = Field(None)
user: Optional[str] = Field(None) user: Optional[str] = Field(None)
# llama.cpp specific parameters # llama.cpp specific parameters
@ -413,9 +455,16 @@ async def create_chat_completion(
exclude = { exclude = {
"n", "n",
"logit_bias", "logit_bias",
"logit_bias_type",
"user", "user",
} }
kwargs = body.dict(exclude=exclude) kwargs = body.dict(exclude=exclude)
if body.logit_bias is not None:
kwargs['logits_processor'] = llama_cpp.LogitsProcessorList([
make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
])
if body.stream: if body.stream:
send_chan, recv_chan = anyio.create_memory_object_stream(10) send_chan, recv_chan = anyio.create_memory_object_stream(10)