From 422ebc89ce45af63c5f919da74ea68188ef70c1c Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Tue, 21 Nov 2023 04:01:36 -0500
Subject: [PATCH] Fix: Add logit_bias to all completion api methods

---
 llama_cpp/llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index f49991c..7d3dc76 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1327,7 +1327,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1828,6 +1828,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
@@ -1876,6 +1877,7 @@ class Llama:
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )
         if stream:
             chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
@@ -1909,6 +1911,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
@@ -1957,6 +1960,7 @@ class Llama:
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )
 
     def create_chat_completion(
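
For context, a minimal usage sketch of the parameter this patch threads through the completion methods. This is an illustration, not part of the patch: the model path, prompt, token id, and bias value are all hypothetical, and the keys are assumed to be token ids serialized as strings (matching the Dict[str, float] annotation above), with each value assumed to be added to that token's logit before sampling.

    from llama_cpp import Llama

    # Hypothetical model path; any local GGUF/GGML model would do.
    llm = Llama(model_path="./models/model.gguf")

    # Token id "13" and the -10.0 bias are illustrative only; a strongly
    # negative value is assumed to discourage sampling that token.
    output = llm.create_completion(
        prompt="Q: Name the planets in the solar system. A:",
        max_tokens=32,
        logit_bias={"13": -10.0},
    )
    print(output["choices"][0]["text"])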