Add support for logit_bias outside of server api. Closes #827

parent c21edb6908
commit 07e47f55ba

3 changed files with 44 additions and 38 deletions
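
With this change, logit_bias can be passed directly to Llama.create_completion and Llama.create_chat_completion instead of only through the server API. A minimal usage sketch of the new parameter (the model path and token id are placeholders, not part of this commit):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # placeholder path

    # Add +5.0 to the logit of token id 15043 at every sampling step.
    output = llm.create_completion(
        "Hello, my name is",
        max_tokens=16,
        logit_bias={15043: 5.0},
    )
    print(output["choices"][0]["text"])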

llama_cpp/llama.py

@@ -1327,6 +1327,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1355,6 +1356,28 @@ class Llama:
         )
         model_name: str = model if model is not None else self.model_path

+        # NOTE: This likely doesn't work correctly for the first token in the prompt
+        # because of the extra space added to the start of the prompt_tokens
+        if logit_bias is not None:
+            logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
+
+            def logit_bias_processor(
+                input_ids: npt.NDArray[np.intc],
+                scores: npt.NDArray[np.single],
+            ) -> npt.NDArray[np.single]:
+                new_scores = np.copy(
+                    scores
+                )  # Does it make sense to copy the whole array or can we just overwrite the original one?
+                for input_id, score in logit_bias_map.items():
+                    new_scores[input_id] = score + scores[input_id]
+                return new_scores
+
+            _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
+            if logits_processor is None:
+                logits_processor = _logit_bias_processor
+            else:
+                logits_processor.extend(_logit_bias_processor)  # extend() mutates in place and returns None, so don't assign its result
+
         if self.verbose:
             self._ctx.reset_timings()
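
Since the bias is wrapped in an ordinary LogitsProcessorList entry, it composes with any processors the caller already supplies. A sketch of combining both on one call (placeholder token ids and model path):

    import numpy as np
    import numpy.typing as npt
    from llama_cpp import Llama, LogitsProcessorList

    def ban_token(
        input_ids: npt.NDArray[np.intc],
        scores: npt.NDArray[np.single],
    ) -> npt.NDArray[np.single]:
        scores[42] = -np.inf  # placeholder id to suppress entirely
        return scores

    llm = Llama(model_path="./models/model.gguf")  # placeholder path
    out = llm.create_completion(
        "Once upon a time",
        logits_processor=LogitsProcessorList([ban_token]),
        logit_bias={15043: 2.5},  # applied alongside the custom processor
    )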

@@ -1963,6 +1986,7 @@ class Llama:
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
@@ -2011,6 +2035,7 @@ class Llama:
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )

     def __getstate__(self):

llama_cpp/llama_chat_format.py

@@ -45,6 +45,7 @@ class LlamaChatCompletionHandler(Protocol):
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -308,6 +309,7 @@ def register_chat_format(name: str):
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -350,6 +352,7 @@ def register_chat_format(name: str):
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )
         return _convert_completion_to_chat(completion_or_chunks, stream=stream)

llama_cpp/server/app.py

@@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel):
     }


-def make_logit_bias_processor(
+def _logit_bias_tokens_to_input_ids(
     llama: llama_cpp.Llama,
     logit_bias: Dict[str, float],
-    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
-):
-    if logit_bias_type is None:
-        logit_bias_type = "input_ids"
-
-    to_bias: Dict[int, float] = {}
-    if logit_bias_type == "input_ids":
-        for input_id, score in logit_bias.items():
-            input_id = int(input_id)
-            to_bias[input_id] = score
-
-    elif logit_bias_type == "tokens":
-        for token, score in logit_bias.items():
-            token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False, special=True):
-                to_bias[input_id] = score
-
-    def logit_bias_processor(
-        input_ids: npt.NDArray[np.intc],
-        scores: npt.NDArray[np.single],
-    ) -> npt.NDArray[np.single]:
-        new_scores = np.copy(scores)  # Does it make sense to copy the whole array or can we just overwrite the original one?
-        for input_id, score in to_bias.items():
-            new_scores[input_id] = score + scores[input_id]
-        return new_scores
-
-    return logit_bias_processor
+) -> Dict[str, float]:
+    to_bias: Dict[str, float] = {}
+    for token, score in logit_bias.items():
+        token = token.encode("utf-8")
+        for input_id in llama.tokenize(token, add_bos=False, special=True):
+            to_bias[str(input_id)] = score
+    return to_bias


 @router.post(
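
The server keeps its tokens-mode extension: when logit_bias_type is "tokens", each string key is tokenized and its score is fanned out to every resulting input id. Roughly equivalent standalone code (placeholder path; the resulting ids depend on the model's vocabulary):

    import llama_cpp

    llm = llama_cpp.Llama(model_path="./models/model.gguf")  # placeholder path

    token_bias = {"Hello": 5.0}
    input_id_bias = {
        str(input_id): score
        for token, score in token_bias.items()
        for input_id in llm.tokenize(token.encode("utf-8"), add_bos=False, special=True)
    }
    # e.g. {"15043": 5.0}; a multi-token key spreads its score over each of its ids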

@@ -694,17 +674,16 @@ async def create_completion(
     exclude = {
         "n",
         "best_of",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
-        )
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )

     if body.grammar is not None:
@@ -851,17 +830,16 @@ async def create_chat_completion(
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
-        )
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
+        )

     if body.grammar is not None:
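
End to end, both endpoints now simply forward logit_bias to the Llama methods, converting token keys to input ids first when logit_bias_type == "tokens". A client-side sketch; the /v1/completions route and localhost:8000 address are assumed defaults, not shown in this diff:

    import json
    import urllib.request

    payload = {
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "logit_bias": {"Hello": 5.0},
        "logit_bias_type": "tokens",  # server maps token strings to input ids
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/completions",  # assumed default host/port
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp)["choices"][0]["text"])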