diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index f4b2d49..f61b077 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -677,6 +677,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
@@ -739,6 +741,8 @@ class Llama:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -848,6 +852,11 @@ class Llama:
                 finish_reason = "length"
                 break
 
+        if stopping_criteria is not None and stopping_criteria(
+            list(self.eval_tokens), self.eval_logits[-1]
+        ):
+            finish_reason = "stop"
+
         if self.verbose:
             llama_cpp.llama_print_timings(self.ctx)
 
@@ -1049,6 +1058,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1091,6 +1102,8 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1118,6 +1131,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1160,6 +1175,8 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            stopping_criteria=stopping_criteria,
+            logits_processor=logits_processor,
         )
 
     def _convert_text_completion_to_chat(
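
A rough usage sketch of the new parameters, not part of the diff. It assumes StoppingCriteriaList and LogitsProcessorList are list-like containers of callables importable from llama_cpp, that a stopping criterion receives (tokens, logits) and returns a bool (as suggested by the call site `stopping_criteria(list(self.eval_tokens), self.eval_logits[-1])` above), and that a logits processor receives (tokens, logits) and returns the modified logits; the model path, function names, and thresholds below are hypothetical.

    from typing import List

    from llama_cpp import Llama, LogitsProcessorList, StoppingCriteriaList

    # Hypothetical model path.
    llm = Llama(model_path="./models/7B/ggml-model.bin")

    def stop_after_n_tokens(tokens: List[int], logits: List[float]) -> bool:
        # Stop generation once the evaluated context exceeds 256 tokens.
        return len(tokens) > 256

    def suppress_token_zero(tokens: List[int], logits: List[float]) -> List[float]:
        # Push token id 0 to -inf so it is never sampled.
        logits[0] = float("-inf")
        return logits

    output = llm(
        "Q: Name the planets in the solar system. A:",
        max_tokens=128,
        stopping_criteria=StoppingCriteriaList([stop_after_n_tokens]),
        logits_processor=LogitsProcessorList([suppress_token_zero]),
    )
    print(output["choices"][0]["text"])

When a criterion fires, the diff sets `finish_reason = "stop"`, so the caller sees the same finish reason as an EOS or stop-string match.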