Add extra logits_processor and stopping_criteria

Andrei Betlen 2023-05-26 03:13:24 -04:00
parent 30bf8ec557
commit 433a2e3e8a


@@ -677,6 +677,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
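
Note: the two new parameters are list containers that behave as a single callable. A minimal sketch of the interfaces this change assumes (the real definitions live elsewhere in the library and may differ in detail):

from typing import List, Protocol

class LogitsProcessor(Protocol):
    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        ...

class LogitsProcessorList(List[LogitsProcessor]):
    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
        # Chain the processors: each sees the token history and the current
        # logits, and returns (possibly modified) logits for the next one.
        for processor in self:
            scores = processor(input_ids, scores)
        return scores

class StoppingCriteria(Protocol):
    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
        ...

class StoppingCriteriaList(List[StoppingCriteria]):
    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
        # Stop generation as soon as any single criterion fires.
        return any(criteria(input_ids, logits) for criteria in self)
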
@@ -739,6 +741,8 @@ class Llama:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -848,6 +852,11 @@ class Llama:
                 finish_reason = "length"
                 break
 
+        if stopping_criteria is not None and stopping_criteria(
+            list(self.eval_tokens), self.eval_logits[-1]
+        ):
+            finish_reason = "stop"
+
         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
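
Note: the check added in this hunk hands the criteria the full token history (list(self.eval_tokens)) plus the logits from the most recent evaluation step (self.eval_logits[-1]). A hypothetical criterion matching that call shape (the helper name and threshold are illustrative, not part of the commit):

from typing import List

def make_length_criteria(limit: int):
    # input_ids is the whole context (prompt + generated tokens), so the
    # limit applies to total context length, not to new tokens alone.
    def criteria(input_ids: List[int], logits: List[float]) -> bool:
        return len(input_ids) >= limit
    return criteria
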
@@ -1049,6 +1058,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
@@ -1091,6 +1102,8 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1118,6 +1131,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
@@ -1160,6 +1175,8 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         )
 
     def _convert_text_completion_to_chat(
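
Note: taken together, both hooks can now be passed straight through __call__ / create_completion. A sketch of end-to-end use, assuming the list types are importable from the llama_cpp package; the model path and token id below are placeholders:

import math
from typing import List

from llama_cpp import Llama, LogitsProcessorList, StoppingCriteriaList

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path

BANNED_TOKEN = 13  # illustrative token id to suppress

def ban_token(input_ids: List[int], scores: List[float]) -> List[float]:
    # Drive the banned token's logit to -inf so it can never be sampled.
    scores[BANNED_TOKEN] = -math.inf
    return scores

def stop_on_long_context(input_ids: List[int], logits: List[float]) -> bool:
    # Fires once the total context (prompt + completion) reaches 64 tokens.
    return len(input_ids) >= 64

output = llm(
    "Q: Name the planets in the solar system. A: ",
    logits_processor=LogitsProcessorList([ban_token]),
    stopping_criteria=StoppingCriteriaList([stop_on_long_context]),
)
print(output["choices"][0]["text"])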