Add StoppingCriteria and LogitsProcessor to generate to match huggingface API

Andrei Betlen 2023-05-25 14:04:54 -04:00
parent c6a9659972
commit 1d247e0f35


@@ -4,7 +4,17 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
+from typing import (
+    List,
+    Optional,
+    Union,
+    Generator,
+    Sequence,
+    Iterator,
+    Deque,
+    Tuple,
+    Callable,
+)
 from collections import deque, OrderedDict
 
 from . import llama_cpp
@@ -72,6 +82,24 @@ class LlamaState:
         self.llama_state_size = llama_state_size
 
 
+LogitsProcessor = Callable[[List[int], List[float]], List[float]]
+
+
+class LogitsProcessorList(List[LogitsProcessor]):
+    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
+        for processor in self:
+            scores = processor(input_ids, scores)
+        return scores
+
+
+StoppingCriteria = Callable[[List[int], List[float]], bool]
+
+
+class StoppingCriteriaList(List[StoppingCriteria]):
+    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
+        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -316,12 +344,10 @@ class Llama:
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        if logits_processors is None:
-            logits_processors = []
 
         n_vocab = self.n_vocab()
         n_ctx = self.n_ctx()
@@ -332,10 +358,10 @@ class Llama:
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        for processor in logits_processors:
-            logits = processor(list(self.eval_tokens), logits)
 
-        self.eval_logits[-1] = logits
+        if logits_processor is not None:
+            logits = logits_processor(list(self.eval_tokens), logits)
+
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
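
Worth noting about the hunk above: `logits` is a flat list of `n_vocab` scores for the last evaluated position, and the whole `LogitsProcessorList` is invoked as a single callable, so each processor's output feeds the next. A standalone sketch of that chaining on a toy 4-token vocabulary (both processors are made up for the example):

from llama_cpp.llama import LogitsProcessorList


def boost_token_zero(input_ids, scores):
    return [scores[0] + 10.0] + scores[1:]


def ban_token_three(input_ids, scores):
    return scores[:3] + [-float("inf")]


processors = LogitsProcessorList([boost_token_zero, ban_token_three])
print(processors([], [1.0, 2.0, 3.0, 4.0]))
# -> [11.0, 2.0, 3.0, -inf]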
@@ -444,8 +470,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
     ):
         """Sample a token from the model.
@@ -478,8 +503,7 @@ class Llama:
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
-            logits_processors=logits_processors
+            logits_processor=logits_processor,
         )
 
     def generate(
@@ -496,7 +520,8 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
@@ -554,8 +579,12 @@ class Llama:
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
-                logits_processors=logits_processors
+                logits_processor=logits_processor,
             )
+            if stopping_criteria is not None and stopping_criteria(
+                list(self.eval_tokens), self.eval_logits[-1]
+            ):
+                return
             tokens_or_none = yield token
             tokens = [token]
             if tokens_or_none is not None:
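
Note the placement of the new check: it runs after `sample()` but before the `yield`, so a token that trips a criterion is never surfaced to the caller. Also, `self.eval_tokens` is only extended by `eval()` on the next pass through the loop, so the just-sampled token is not yet in `input_ids` when a criterion runs; a criterion keyed off the last token therefore fires one iteration after that token is yielded, as in this sketch (the token id is an arbitrary assumption):

from typing import List

NEWLINE_TOKEN = 13  # assumption: the id the llama vocabulary gives "\n"


def stop_after_newline(input_ids: List[int], logits: List[float]) -> bool:
    # input_ids is the evaluated context; the token sampled on this
    # iteration has not been appended yet.
    return len(input_ids) > 0 and input_ids[-1] == NEWLINE_TOKEN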
@@ -651,14 +680,9 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
-        stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
-
-        if stopping_criterias is None:
-            stopping_criterias = []
 
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
         completion_tokens: List[int] = []
@@ -720,7 +744,6 @@ class Llama:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
-            logits_processors=logits_processors
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -728,14 +751,6 @@ class Llama:
                 break
 
             completion_tokens.append(token)
 
-            for stopping_crit in stopping_criterias:
-                if stopping_crit(completion_tokens, None):
-                    text = self.detokenize(completion_tokens)
-                    finish_reason = "stop"
-                    break
-
-            if finish_reason == "stop":
-                break
 
             all_text = self.detokenize(completion_tokens)
@@ -1035,8 +1050,6 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
-        stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
@@ -1079,9 +1092,6 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
-            logits_processors=logits_processors,
-            stopping_criterias=stopping_criterias
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
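
One consequence of the hunks above: with the list-based hooks stripped from the completion methods, `generate()` and `sample()` are where these callables plug in as of this commit. And since a criterion sees the whole evaluated context rather than a count of new tokens, a transformers-style max-new-tokens limit takes a small closure over the prompt length, sketched here (names are illustrative):

from typing import Callable, List

from llama_cpp.llama import StoppingCriteriaList


def make_max_new_tokens_criteria(
    prompt_len: int, max_new_tokens: int
) -> Callable[[List[int], List[float]], bool]:
    def criteria(input_ids: List[int], logits: List[float]) -> bool:
        # New tokens = evaluated context minus the original prompt.
        return len(input_ids) - prompt_len >= max_new_tokens

    return criteria


# Built per request, after tokenizing the prompt:
# tokens = llm.tokenize(b"...")
# stopping_criteria = StoppingCriteriaList([make_max_new_tokens_criteria(len(tokens), 64)])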