Fix: Add logit_bias to all completion api methods

This commit is contained in:
Andrei Betlen 2023-11-21 04:01:36 -05:00
parent 79efc85206
commit 422ebc89ce

View file

@ -1327,7 +1327,7 @@ class Llama:
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None, grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[int, float]] = None, logit_bias: Optional[Dict[str, float]] = None,
) -> Union[ ) -> Union[
Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
]: ]:
@ -1828,6 +1828,7 @@ class Llama:
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None, grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
"""Generate text from a prompt. """Generate text from a prompt.
@ -1876,6 +1877,7 @@ class Llama:
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
logit_bias=logit_bias,
) )
if stream: if stream:
chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
@ -1909,6 +1911,7 @@ class Llama:
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None, grammar: Optional[LlamaGrammar] = None,
logit_bias: Optional[Dict[str, float]] = None,
) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
"""Generate text from a prompt. """Generate text from a prompt.
@ -1957,6 +1960,7 @@ class Llama:
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
logit_bias=logit_bias,
) )
def create_chat_completion( def create_chat_completion(