From 07e47f55ba3a72e6022ebd12fb036373a7a7c4dd Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 21 Nov 2023 03:59:46 -0500 Subject: [PATCH] Add support for logit_bias outside of server api. Closes #827 --- llama_cpp/llama.py | 25 ++++++++++++++++ llama_cpp/llama_chat_format.py | 3 ++ llama_cpp/server/app.py | 54 ++++++++++------------------------ 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index adb767f..f49991c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1327,6 +1327,7 @@ class Llama: stopping_criteria: Optional[StoppingCriteriaList] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, + logit_bias: Optional[Dict[int, float]] = None, ) -> Union[ Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] ]: @@ -1355,6 +1356,28 @@ class Llama: ) model_name: str = model if model is not None else self.model_path + # NOTE: This likely doesn't work correctly for the first token in the prompt + # because of the extra space added to the start of the prompt_tokens + if logit_bias is not None: + logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()} + + def logit_bias_processor( + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], + ) -> npt.NDArray[np.single]: + new_scores = np.copy( + scores + ) # Does it make sense to copy the whole array or can we just overwrite the original one? + for input_id, score in logit_bias_map.items(): + new_scores[input_id] = score + scores[input_id] + return new_scores + + _logit_bias_processor = LogitsProcessorList([logit_bias_processor]) + if logits_processor is None: + logits_processor = _logit_bias_processor + else: + logits_processor = logits_processor.extend(_logit_bias_processor) + if self.verbose: self._ctx.reset_timings() @@ -1963,6 +1986,7 @@ class Llama: model: Optional[str] = None, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, + logit_bias: Optional[Dict[str, float]] = None, ) -> Union[ CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse] ]: @@ -2011,6 +2035,7 @@ class Llama: model=model, logits_processor=logits_processor, grammar=grammar, + logit_bias=logit_bias, ) def __getstate__(self): diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index a855305..8efbaae 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -45,6 +45,7 @@ class LlamaChatCompletionHandler(Protocol): model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, + logit_bias: Optional[Dict[str, float]] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, @@ -308,6 +309,7 @@ def register_chat_format(name: str): model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, + logit_bias: Optional[Dict[str, float]] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, @@ -350,6 +352,7 @@ def register_chat_format(name: str): model=model, logits_processor=logits_processor, grammar=grammar, + logit_bias=logit_bias, ) return _convert_completion_to_chat(completion_or_chunks, stream=stream) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index b39a462..9262b20 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -646,36 +646,16 @@ class CreateCompletionRequest(BaseModel): } -def make_logit_bias_processor( +def _logit_bias_tokens_to_input_ids( llama: llama_cpp.Llama, logit_bias: Dict[str, float], - logit_bias_type: Optional[Literal["input_ids", "tokens"]], -): - if logit_bias_type is None: - logit_bias_type = "input_ids" - - to_bias: Dict[int, float] = {} - if logit_bias_type == "input_ids": - for input_id, score in logit_bias.items(): - input_id = int(input_id) - to_bias[input_id] = score - - elif logit_bias_type == "tokens": - for token, score in logit_bias.items(): - token = token.encode("utf-8") - for input_id in llama.tokenize(token, add_bos=False, special=True): - to_bias[input_id] = score - - def logit_bias_processor( - input_ids: npt.NDArray[np.intc], - scores: npt.NDArray[np.single], - ) -> npt.NDArray[np.single]: - new_scores = np.copy(scores) # Does it make sense to copy the whole array or can we just overwrite the original one? - for input_id, score in to_bias.items(): - new_scores[input_id] = score + scores[input_id] - return new_scores - - return logit_bias_processor +) -> Dict[str, float]: + to_bias: Dict[str, float] = {} + for token, score in logit_bias.items(): + token = token.encode("utf-8") + for input_id in llama.tokenize(token, add_bos=False, special=True): + to_bias[str(input_id)] = score + return to_bias @router.post( @@ -694,17 +674,16 @@ async def create_completion( exclude = { "n", "best_of", - "logit_bias", "logit_bias_type", "user", } kwargs = body.model_dump(exclude=exclude) if body.logit_bias is not None: - kwargs["logits_processor"] = llama_cpp.LogitsProcessorList( - [ - make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type), - ] + kwargs["logit_bias"] = ( + _logit_bias_tokens_to_input_ids(llama, body.logit_bias) + if body.logit_bias_type == "tokens" + else body.logit_bias ) if body.grammar is not None: @@ -851,17 +830,16 @@ async def create_chat_completion( ) -> llama_cpp.ChatCompletion: exclude = { "n", - "logit_bias", "logit_bias_type", "user", } kwargs = body.model_dump(exclude=exclude) if body.logit_bias is not None: - kwargs["logits_processor"] = llama_cpp.LogitsProcessorList( - [ - make_logit_bias_processor(llama, body.logit_bias, body.logit_bias_type), - ] + kwargs["logit_bias"] = ( + _logit_bias_tokens_to_input_ids(llama, body.logit_bias) + if body.logit_bias_type == "tokens" + else body.logit_bias ) if body.grammar is not None: