Add StoppingCriteria and LogitsProcessor to generate() to match the Hugging Face API

Andrei Betlen 2023-05-25 14:04:54 -04:00
parent c6a9659972
commit 1d247e0f35

llama_cpp/llama.py

@@ -4,7 +4,17 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
+from typing import (
+    List,
+    Optional,
+    Union,
+    Generator,
+    Sequence,
+    Iterator,
+    Deque,
+    Tuple,
+    Callable,
+)
 from collections import deque, OrderedDict

 from . import llama_cpp
@@ -72,6 +82,24 @@ class LlamaState:
         self.llama_state_size = llama_state_size


+LogitsProcessor = Callable[[List[int], List[float]], List[float]]
+
+
+class LogitsProcessorList(List[LogitsProcessor]):
+    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
+        for processor in self:
+            scores = processor(input_ids, scores)
+        return scores
+
+
+StoppingCriteria = Callable[[List[int], List[float]], bool]
+
+
+class StoppingCriteriaList(List[StoppingCriteria]):
+    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
+        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -316,12 +344,10 @@ class Llama:
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        if logits_processors is None:
-            logits_processors = []

         n_vocab = self.n_vocab()
         n_ctx = self.n_ctx()
@@ -332,10 +358,10 @@
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        for processor in logits_processors:
-            logits = processor(list(self.eval_tokens), logits)
-
-        self.eval_logits[-1] = logits
+
+        if logits_processor is not None:
+            logits = logits_processor(list(self.eval_tokens), logits)
+
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
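Note that `_sample` now applies the processor once to the final logits row, passing the full evaluated token history, and the result feeds sampling directly rather than being written back into `self.eval_logits` as the old loop did. A toy processor with the expected signature (the name and penalty value are hypothetical):

from typing import List

def penalize_seen(input_ids: List[int], scores: List[float]) -> List[float]:
    # Nudge down every token that already occurred in the context,
    # a crude repetition penalty for illustration only.
    for token_id in set(input_ids):
        scores[token_id] -= 2.0
    return scores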
@@ -444,8 +470,7 @@
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
     ):
         """Sample a token from the model.
@@ -478,8 +503,7 @@
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
-            logits_processors=logits_processors
+            logits_processor=logits_processor,
         )

     def generate(
@@ -496,7 +520,8 @@
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
@@ -554,8 +579,12 @@
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
-                logits_processors=logits_processors
+                logits_processor=logits_processor,
             )
+            if stopping_criteria is not None and stopping_criteria(
+                list(self.eval_tokens), self.eval_logits[-1]
+            ):
+                return
             tokens_or_none = yield token
             tokens = [token]
             if tokens_or_none is not None:
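The stopping check runs immediately after sampling: criteria receive `list(self.eval_tokens)` (which does not yet include the freshly sampled token) and the last logits row, and a hit returns from the generator without yielding that token. Because the logits are passed along, a criterion can key off model state as well as token history; a hedged sketch (the threshold heuristic is an arbitrary illustration):

from typing import List

def stop_when_confident(input_ids: List[int], logits: List[float]) -> bool:
    # Stop once the raw top logit crosses a fixed threshold.
    return max(logits) > 20.0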
@@ -651,14 +680,9 @@
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
-        stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
-
-        if stopping_criterias is None:
-            stopping_criterias = []

         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
         completion_tokens: List[int] = []
@@ -720,7 +744,6 @@
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
-            logits_processors=logits_processors
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -728,14 +751,6 @@
                 break

             completion_tokens.append(token)
-
-            for stopping_crit in stopping_criterias:
-                if stopping_crit(completion_tokens, None):
-                    text = self.detokenize(completion_tokens)
-                    finish_reason = "stop"
-                    break
-
-            if finish_reason == "stop":
-                break

             all_text = self.detokenize(completion_tokens)
@@ -1035,8 +1050,6 @@
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
-        stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
@@ -1079,9 +1092,6 @@
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
-            logits_processors=logits_processors,
-            stopping_criterias=stopping_criterias
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
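End to end, the new hooks ride through `Llama.generate` under the keyword names shown in the signatures above; a usage sketch under stated assumptions (the model path, prompt, sampling values, and helper callables are illustrative, not fixtures from this commit):

from typing import List

from llama_cpp import Llama, LogitsProcessorList, StoppingCriteriaList

llm = Llama(model_path="./models/7B/ggml-model.bin")  # assumed path

def no_repeats(input_ids: List[int], scores: List[float]) -> List[float]:
    for token_id in set(input_ids):
        scores[token_id] -= 1.0  # toy anti-repetition nudge
    return scores

def max_len(input_ids: List[int], logits: List[float]) -> bool:
    return len(input_ids) >= 64

tokens = llm.tokenize(b"The capital of France is")
out: List[int] = []
for token in llm.generate(
    tokens,
    top_k=40,
    top_p=0.95,
    temp=0.8,
    repeat_penalty=1.1,
    logits_processor=LogitsProcessorList([no_repeats]),
    stopping_criteria=StoppingCriteriaList([max_len]),
):
    if token == llm.token_eos():
        break
    out.append(token)

print(llm.detokenize(out).decode("utf-8", errors="ignore"))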