Add ability to pass in penalize_nl param (#1068)

This commit is contained in:
Stephen Hankinson 2024-01-10 03:46:27 -04:00 committed by GitHub
parent 2ddce7294e
commit df3be58d6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1201,6 +1201,7 @@ class Llama:
mirostat_mode: int = 0, mirostat_mode: int = 0,
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
grammar: Optional[LlamaGrammar] = None, grammar: Optional[LlamaGrammar] = None,
@ -1261,6 +1262,7 @@ class Llama:
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
penalize_nl=penalize_nl,
) )
if stopping_criteria is not None and stopping_criteria( if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :] self._input_ids, self._scores[-1, :]