Add support for logit_bias outside of server api. Closes #827

Andrei Betlen 2023-11-21 03:59:46 -05:00
parent c21edb6908
commit 07e47f55ba
3 changed files with 44 additions and 38 deletions
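
With this change, logit_bias can be passed straight to Llama.create_completion / Llama.create_chat_completion instead of only through the server API. A minimal sketch of the new call (the model path and token id below are placeholders, not part of this commit):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # placeholder path

    # logit_bias maps token id -> additive adjustment applied to that token's logit.
    # 15043 is an illustrative id; look up real ids with llm.tokenize().
    output = llm.create_completion(
        "The capital of France is",
        max_tokens=8,
        logit_bias={15043: -100.0},  # strongly discourage this token
    )
    print(output["choices"][0]["text"])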


@@ -1327,6 +1327,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1355,6 +1356,28 @@ class Llama:
         )
         model_name: str = model if model is not None else self.model_path

+        # NOTE: This likely doesn't work correctly for the first token in the prompt
+        # because of the extra space added to the start of the prompt_tokens
+        if logit_bias is not None:
+            logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
+
+            def logit_bias_processor(
+                input_ids: npt.NDArray[np.intc],
+                scores: npt.NDArray[np.single],
+            ) -> npt.NDArray[np.single]:
+                new_scores = np.copy(
+                    scores
+                )  # Does it make sense to copy the whole array or can we just overwrite the original one?
+                for input_id, score in logit_bias_map.items():
+                    new_scores[input_id] = score + scores[input_id]
+                return new_scores
+
+            _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
+            if logits_processor is None:
+                logits_processor = _logit_bias_processor
+            else:
+                logits_processor = logits_processor.extend(_logit_bias_processor)
+
         if self.verbose:
             self._ctx.reset_timings()
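
The bias is applied as an extra LogitsProcessor: each entry adds its score to the raw logit of that token id before sampling. A standalone sketch of the same arithmetic (made-up ids and a zeroed score array, for illustration only):

    import numpy as np
    import numpy.typing as npt

    logit_bias_map = {3: 5.0, 7: -100.0}  # token id -> bias, illustrative values

    scores: npt.NDArray[np.single] = np.zeros(10, dtype=np.single)  # stand-in for real logits
    new_scores = np.copy(scores)
    for input_id, bias in logit_bias_map.items():
        new_scores[input_id] = bias + scores[input_id]

    print(new_scores[3], new_scores[7])  # 5.0 -100.0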
@@ -1963,6 +1986,7 @@ class Llama:
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
@@ -2011,6 +2035,7 @@ class Llama:
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )

     def __getstate__(self):
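
create_chat_completion forwards logit_bias to the chat format handler and on to create_completion; note that keys are strings here (Dict[str, float]) and are int()-converted downstream. A sketch, again with a placeholder model path and an illustrative token id:

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # placeholder path

    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Say hello."}],
        logit_bias={"15043": 10.0},  # string token id, illustrative value
    )
    print(response["choices"][0]["message"]["content"])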


@@ -45,6 +45,7 @@ class LlamaChatCompletionHandler(Protocol):
         model: Optional[str] = None,
         logits_processor: Optional[llama.LogitsProcessorList] = None,
         grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
         **kwargs,  # type: ignore
     ) -> Union[
         llama_types.CreateChatCompletionResponse,
@@ -308,6 +309,7 @@ def register_chat_format(name: str):
            model: Optional[str] = None,
            logits_processor: Optional[llama.LogitsProcessorList] = None,
            grammar: Optional[llama.LlamaGrammar] = None,
+           logit_bias: Optional[Dict[str, float]] = None,
            **kwargs,  # type: ignore
        ) -> Union[
            llama_types.CreateChatCompletionResponse,
@@ -350,6 +352,7 @@ def register_chat_format(name: str):
                model=model,
                logits_processor=logits_processor,
                grammar=grammar,
+               logit_bias=logit_bias,
            )
            return _convert_completion_to_chat(completion_or_chunks, stream=stream)


@@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel):
     }


-def make_logit_bias_processor(
+def _logit_bias_tokens_to_input_ids(
     llama: llama_cpp.Llama,
     logit_bias: Dict[str, float],
-    logit_bias_type: Optional[Literal["input_ids", "tokens"]],
-):
-    if logit_bias_type is None:
-        logit_bias_type = "input_ids"
-
-    to_bias: Dict[int, float] = {}
-    if logit_bias_type == "input_ids":
-        for input_id, score in logit_bias.items():
-            input_id = int(input_id)
-            to_bias[input_id] = score
-
-    elif logit_bias_type == "tokens":
-        for token, score in logit_bias.items():
-            token = token.encode("utf-8")
-            for input_id in llama.tokenize(token, add_bos=False, special=True):
-                to_bias[input_id] = score
-
-    def logit_bias_processor(
-        input_ids: npt.NDArray[np.intc],
-        scores: npt.NDArray[np.single],
-    ) -> npt.NDArray[np.single]:
-        new_scores = np.copy(scores)  # Does it make sense to copy the whole array or can we just overwrite the original one?
-        for input_id, score in to_bias.items():
-            new_scores[input_id] = score + scores[input_id]
-        return new_scores
-
-    return logit_bias_processor
+) -> Dict[str, float]:
+    to_bias: Dict[str, float] = {}
+    for token, score in logit_bias.items():
+        token = token.encode("utf-8")
+        for input_id in llama.tokenize(token, add_bos=False, special=True):
+            to_bias[str(input_id)] = score
+    return to_bias


 @router.post(
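
_logit_bias_tokens_to_input_ids re-keys a text-keyed bias map by the token ids that llama.tokenize() returns, so the "tokens" style can ride on the new logit_bias parameter. Roughly equivalent usage outside the server (placeholder model path; the resulting ids depend on the model's vocabulary):

    from llama_cpp import Llama

    llm = Llama(model_path="./models/model.gguf")  # placeholder path

    token_bias = {"Hello": -100.0}  # "tokens"-style bias keyed by text
    id_bias = {}
    for text, score in token_bias.items():
        for input_id in llm.tokenize(text.encode("utf-8"), add_bos=False, special=True):
            id_bias[str(input_id)] = score
    print(id_bias)  # e.g. {'15043': -100.0}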
@@ -694,17 +674,16 @@ async def create_completion(
     exclude = {
         "n",
         "best_of",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
@@ -851,17 +830,16 @@ async def create_chat_completion(
 ) -> llama_cpp.ChatCompletion:
     exclude = {
         "n",
-        "logit_bias",
         "logit_bias_type",
         "user",
     }
     kwargs = body.model_dump(exclude=exclude)

     if body.logit_bias is not None:
-        kwargs["logits_processor"] = llama_cpp.LogitsProcessorList(
-            [
-                make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type),
-            ]
+        kwargs["logit_bias"] = (
+            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+            if body.logit_bias_type == "tokens"
+            else body.logit_bias
         )

     if body.grammar is not None:
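
From a client's point of view the server behaves as before: logit_bias defaults to input-id keys and logit_bias_type="tokens" switches to text keys, but the bias is now routed through Llama.create_completion's logit_bias parameter rather than a hand-built logits processor. A rough client-side sketch (assumes the default python -m llama_cpp.server setup on port 8000):

    import requests

    base = "http://localhost:8000/v1"

    # Default style: keys are token ids, forwarded unchanged as logit_bias.
    r = requests.post(f"{base}/completions", json={
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "logit_bias": {"15043": -100.0},  # illustrative token id
    })
    print(r.json()["choices"][0]["text"])

    # "tokens" style: keys are text, converted via _logit_bias_tokens_to_input_ids.
    r = requests.post(f"{base}/chat/completions", json={
        "messages": [{"role": "user", "content": "Say hello."}],
        "logit_bias": {"Hello": -100.0},
        "logit_bias_type": "tokens",
    })
    print(r.json()["choices"][0]["message"]["content"])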