From 1d247e0f350948667553f3c880f8df40f0b5c787 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 25 May 2023 14:04:54 -0400 Subject: [PATCH] Add StoppingCriteria and LogitsProcessor to generate to match huggingface API --- llama_cpp/llama.py | 74 ++++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 32 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 144671b..b7a8d79 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -4,7 +4,17 @@ import uuid import time import math import multiprocessing -from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable +from typing import ( + List, + Optional, + Union, + Generator, + Sequence, + Iterator, + Deque, + Tuple, + Callable, +) from collections import deque, OrderedDict from . import llama_cpp @@ -72,6 +82,24 @@ class LlamaState: self.llama_state_size = llama_state_size +LogitsProcessor = Callable[[List[int], List[float]], List[float]] + + +class LogitsProcessorList(List[LogitsProcessor]): + def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +StoppingCriteria = Callable[[List[int], List[float]], bool] + + +class StoppingCriteriaList(List[StoppingCriteria]): + def __call__(self, input_ids: List[int], logits: List[float]) -> bool: + return any([stopping_criteria(input_ids, logits) for stopping_criteria in self]) + + class Llama: """High-level Python wrapper for a llama.cpp model.""" @@ -316,12 +344,10 @@ class Llama: mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, penalize_nl: bool = True, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None + logits_processor: Optional[LogitsProcessorList] = None, ): assert self.ctx is not None assert len(self.eval_logits) > 0 - if logits_processors is None: - logits_processors = [] n_vocab = self.n_vocab() n_ctx = self.n_ctx() @@ -332,10 +358,10 @@ class Llama: else last_n_tokens_size ) logits = self.eval_logits[-1] - for processor in logits_processors: - logits = processor(list(self.eval_tokens), logits) - self.eval_logits[-1] = logits + if logits_processor is not None: + logits = logits_processor(list(self.eval_tokens), logits) + nl_logit = logits[self._token_nl] candidates = self._candidates for i, logit in enumerate(logits): @@ -444,8 +470,7 @@ class Llama: mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, penalize_nl: bool = True, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None - + logits_processor: Optional[LogitsProcessorList] = None, ): """Sample a token from the model. @@ -478,8 +503,7 @@ class Llama: mirostat_tau=llama_cpp.c_float(mirostat_tau), mirostat_eta=llama_cpp.c_float(mirostat_eta), penalize_nl=penalize_nl, - logits_processors=logits_processors - + logits_processor=logits_processor, ) def generate( @@ -496,7 +520,8 @@ class Llama: mirostat_mode: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, ) -> Generator[int, Optional[Sequence[int]], None]: """Create a generator of tokens from a prompt. @@ -554,8 +579,12 @@ class Llama: mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, - logits_processors=logits_processors + logits_processor=logits_processor, ) + if stopping_criteria is not None and stopping_criteria( + list(self.eval_tokens), self.eval_logits[-1] + ): + return tokens_or_none = yield token tokens = [token] if tokens_or_none is not None: @@ -651,14 +680,9 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None, - stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None, ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]: assert self.ctx is not None - if stopping_criterias is None: - stopping_criterias = [] - completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) completion_tokens: List[int] = [] @@ -720,7 +744,6 @@ class Llama: frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, repeat_penalty=repeat_penalty, - logits_processors=logits_processors ): if token == self._token_eos: text = self.detokenize(completion_tokens) @@ -728,14 +751,6 @@ class Llama: break completion_tokens.append(token) - for stopping_crit in stopping_criterias: - if stopping_crit(completion_tokens, None): - text = self.detokenize(completion_tokens) - finish_reason = "stop" - break - - if finish_reason == "stop": - break all_text = self.detokenize(completion_tokens) @@ -1035,8 +1050,6 @@ class Llama: mirostat_tau: float = 5.0, mirostat_eta: float = 0.1, model: Optional[str] = None, - logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None, - stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None ) -> Union[Completion, Iterator[CompletionChunk]]: """Generate text from a prompt. @@ -1079,9 +1092,6 @@ class Llama: mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, model=model, - logits_processors=logits_processors, - stopping_criterias=stopping_criterias - ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks