Add support for logit_bias outside of the server API. Closes #827
parent c21edb6908
commit 07e47f55ba
3 changed files with 44 additions and 38 deletions
@@ -1327,6 +1327,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1355,6 +1356,28 @@ class Llama:
         )
         model_name: str = model if model is not None else self.model_path

+        # NOTE: This likely doesn't work correctly for the first token in the prompt
+        # because of the extra space added to the start of the prompt_tokens
+        if logit_bias is not None:
+            logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
+
+            def logit_bias_processor(
+                input_ids: npt.NDArray[np.intc],
+                scores: npt.NDArray[np.single],
+            ) -> npt.NDArray[np.single]:
+                new_scores = np.copy(
+                    scores
+                )  # Does it make sense to copy the whole array or can we just overwrite the original one?
+                for input_id, score in logit_bias_map.items():
+                    new_scores[input_id] = score + scores[input_id]
+                return new_scores
+
+            _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
+            if logits_processor is None:
+                logits_processor = _logit_bias_processor
+            else:
+                logits_processor = logits_processor.extend(_logit_bias_processor)
+
         if self.verbose:
             self._ctx.reset_timings()
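The added block folds `logit_bias` into the `logits_processor` pipeline by adding each bias to the existing logit rather than overwriting it. A minimal standalone sketch of that arithmetic, with toy values invented for the example:

import numpy as np

# toy logits for a 3-token vocabulary (values are illustrative)
scores = np.array([1.0, 2.0, 3.0], dtype=np.single)
logit_bias_map = {1: -100.0}  # strongly suppress token id 1

# mirror of the processor in the hunk above: bias is added to the current score
new_scores = np.copy(scores)
for input_id, score in logit_bias_map.items():
    new_scores[input_id] = score + scores[input_id]

print(new_scores)  # token 1 drops from 2.0 to -98.0; the others are untouched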
@@ -1963,6 +1986,7 @@ class Llama:
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
@@ -2011,6 +2035,7 @@ class Llama:
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )

     def __getstate__(self):
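With the signature changes above, a bias map can be passed directly to `create_completion` instead of hand-building a `LogitsProcessorList`. A hypothetical usage sketch (the model path and token id are assumptions for illustration, not taken from the commit):

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # assumed local model path

output = llm.create_completion(
    "The capital of France is",
    max_tokens=8,
    logit_bias={15043: -100.0},  # token id -> additive bias; the id is illustrative
)
print(output["choices"][0]["text"])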
@@ -45,6 +45,7 @@ class LlamaChatCompletionHandler(Protocol):
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -308,6 +309,7 @@ def register_chat_format(name: str):
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -350,6 +352,7 @@ def register_chat_format(name: str):
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )
         return _convert_completion_to_chat(completion_or_chunks, stream=stream)
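The chat-format plumbing above simply forwards `logit_bias` to the underlying completion call. On the chat side the mapping is keyed by token id as a string (`Dict[str, float]`), matching the new signature. Again a hedged sketch with illustrative values and an assumed setup:

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf", chat_format="llama-2")  # assumed setup

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Greet me."}],
    logit_bias={"15043": -100.0},  # token id as a string -> additive bias
)
print(response["choices"][0]["message"]["content"])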
@@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel):
     }


-def make_logit_bias_processor(
+def _logit_bias_tokens_to_input_ids(
     llama: llama_cpp.Llama,
     logit_bias: Dict[str, float],
-    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
-):
-    if logit_bias_type is None:
-        logit_bias_type = "input_ids"
-
-    to_bias: Dict[int, float] = {}
-    if logit_bias_type == "input_ids":
-        for input_id, score in logit_bias.items():
-            input_id = int(input_id)
-            to_bias[input_id] = score
-
-    elif logit_bias_type == "tokens":
-        for token, score in logit_bias.items():
-            token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False, special=True):
-                to_bias[input_id] = score
-
-    def logit_bias_processor(
-        input_ids: npt.NDArray[np.intc],
-        scores: npt.NDArray[np.single],
-    ) -> npt.NDArray[np.single]:
-        new_scores = np.copy(scores)  # Does it make sense to copy the whole array or can we just overwrite the original one?
-        for input_id, score in to_bias.items():
-            new_scores[input_id] = score + scores[input_id]
-        return new_scores
-
-    return logit_bias_processor
+) -> Dict[str, float]:
+    to_bias: Dict[str, float] = {}
+    for token, score in logit_bias.items():
+        token = token.encode("utf-8")
+        for input_id in llama.tokenize(token, add_bos=False, special=True):
+            to_bias[str(input_id)] = score
+    return to_bias


 @router.post(
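`_logit_bias_tokens_to_input_ids` replaces the old `make_logit_bias_processor`: instead of building a logits processor in the server, it only translates a token-string bias map into the id-keyed map the library now accepts. A self-contained sketch with a stub tokenizer (real ids come from the loaded model's vocabulary; the stub and its ids are invented for the example):

from typing import Dict, List

def tokenize_stub(text: bytes) -> List[int]:
    # stand-in for llama.tokenize(text, add_bos=False, special=True);
    # the ids below are invented for the example
    vocab = {b"Hello": [15043], b" world": [3186]}
    return vocab.get(text, [0])

def logit_bias_tokens_to_input_ids(logit_bias: Dict[str, float]) -> Dict[str, float]:
    to_bias: Dict[str, float] = {}
    for token, score in logit_bias.items():
        for input_id in tokenize_stub(token.encode("utf-8")):
            to_bias[str(input_id)] = score  # keys become token ids as strings
    return to_bias

print(logit_bias_tokens_to_input_ids({"Hello": -100.0}))  # {'15043': -100.0}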
@@ -694,17 +674,16 @@
     exclude = {
         "n",
         "best_of",
         "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
@@ -851,17 +830,16 @@
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
         "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
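End to end, both server endpoints now map `logit_bias_type == "tokens"` through the helper and hand the resulting id-keyed map straight to the library. A hedged request sketch against a locally running server (host, port, and values are assumptions):

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logit_bias": {"Paris": -100.0},  # token strings...
        "logit_bias_type": "tokens",      # ...converted server-side to input ids
    },
)
print(resp.json()["choices"][0]["text"])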