diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index f49991c..7d3dc76 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -1327,7 +1327,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[
         Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
     ]:
@@ -1828,6 +1828,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
@@ -1876,6 +1877,7 @@ class Llama:
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )
         if stream:
             chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
@@ -1909,6 +1911,7 @@ class Llama:
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
         """Generate text from a prompt.
 
@@ -1957,6 +1960,7 @@ class Llama:
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
+            logit_bias=logit_bias,
         )
 
     def create_chat_completion(
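
A minimal usage sketch, not part of the diff: after this change, callers can pass logit_bias through the public completion entry points, and it is forwarded down to the internal completion method. The method names outside the hunk context, the model path, and the token id below are assumptions for illustration; keys are token ids encoded as strings, per the new Dict[str, float] annotation.

from llama_cpp import Llama

# Placeholder model path; any local GGUF model would do here.
llm = Llama(model_path="./models/7B/model.gguf")

# logit_bias maps a token id (string key, per the Dict[str, float] annotation
# in this diff) to an additive logit offset. A large negative value such as
# -100.0 effectively bans the token; a positive value favors it.
out = llm.create_completion(
    "The planets of the solar system are",
    max_tokens=32,
    logit_bias={"15043": -100.0},  # 15043 is an illustrative token id
)
print(out["choices"][0]["text"])

The same keyword should also work through Llama.__call__, since the hunks above thread logit_bias through both wrappers into the shared completion path.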