From 5bb780d455d4158870c231a7fde1fa16863361f1 Mon Sep 17 00:00:00 2001 From: Maximilian-Winter Date: Wed, 24 May 2023 21:55:44 +0200 Subject: [PATCH] Implemented logit processors and stop criteria's --- llama_cpp/llama.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 916fe07..cf1e719 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -316,6 +316,7 @@ class Llama: mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, + logits_processors=None ): assert self.ctx is not None assert len(self.eval_logits) > 0 @@ -328,6 +329,10 @@ class Llama: else last_n_tokens_size ) logits = self.eval_logits[-1] + for processor in logits_processors: + logits = processor(list(self.eval_tokens), logits) + + self.eval_logits[-1] = logits nl_logit = logits[self._token_nl] candidates = self._candidates for i, logit in enumerate(logits): @@ -436,6 +441,8 @@ class Llama: mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, penalize_nl: bool = True, + logits_processors=None + ): """Sample a token from the model. @@ -468,6 +475,8 @@ class Llama: mirostat_tau=llama_cpp.c_float(mirostat_tau), mirostat_eta=llama_cpp.c_float(mirostat_eta), penalize_nl=penalize_nl, + logits_processors=logits_processors + ) def generate( @@ -484,6 +493,7 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, + logits_processors=None ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -541,6 +551,7 @@ class Llama: mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, + logits_processors=logits_processors ) tokens_or_none = yield token tokens = [token] @@ -637,6 +648,8 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, + logits_processors=None, + stopping_criterias=None ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None completion_id: str = f"cmpl-{str(uuid.uuid4())}" @@ -700,6 +713,7 @@ class Llama: frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, repeat_penalty=repeat_penalty, + logits_processors=logits_processors ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -707,6 +721,14 @@ class Llama: break completion_tokens.append(token) + for stopping_crit in stopping_criterias: + if stopping_crit(completion_tokens, None): + text = self.detokenize(completion_tokens) + finish_reason = "stop" + break + + if finish_reason == "stop": + break all_text = self.detokenize(completion_tokens) @@ -1006,6 +1028,8 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, + logits_processors=None, + stopping_criterias=None ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1048,6 +1072,9 @@ class Llama: mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, model=model, + logits_processors=logits_processors, + stopping_criterias=stopping_criterias + ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks