Add StoppingCriteria and LogitsProcessor to generate to match huggingface API
parent c6a9659972
commit 1d247e0f35
1 changed file with 42 additions and 32 deletions
@@ -4,7 +4,17 @@ import uuid
 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple, Callable
+from typing import (
+    List,
+    Optional,
+    Union,
+    Generator,
+    Sequence,
+    Iterator,
+    Deque,
+    Tuple,
+    Callable,
+)
 from collections import deque, OrderedDict
 
 from . import llama_cpp
@@ -72,6 +82,24 @@ class LlamaState:
         self.llama_state_size = llama_state_size
 
 
+LogitsProcessor = Callable[[List[int], List[float]], List[float]]
+
+
+class LogitsProcessorList(List[LogitsProcessor]):
+    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
+        for processor in self:
+            scores = processor(input_ids, scores)
+        return scores
+
+
+StoppingCriteria = Callable[[List[int], List[float]], bool]
+
+
+class StoppingCriteriaList(List[StoppingCriteria]):
+    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
+        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
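Both new types are plain callables on (input_ids, scores), so user code can match the Hugging Face pattern without importing transformers. A minimal sketch of user-side helpers against these signatures — the factory names below are illustrative, not part of this commit, and the import assumes the new aliases are re-exported at package level:

from typing import List

from llama_cpp import LogitsProcessor, StoppingCriteria


def suppress_token(token_id: int) -> LogitsProcessor:
    # Illustrative: pin one token's logit to -inf so it is never sampled.
    def processor(input_ids: List[int], scores: List[float]) -> List[float]:
        scores[token_id] = float("-inf")
        return scores
    return processor


def max_total_tokens(limit: int) -> StoppingCriteria:
    # Illustrative: stop once the context reaches `limit` tokens; note that
    # _sample passes list(self.eval_tokens), which includes the prompt.
    def criteria(input_ids: List[int], logits: List[float]) -> bool:
        return len(input_ids) >= limit
    return criteria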
@@ -316,12 +344,10 @@ class Llama:
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
-        if logits_processors is None:
-            logits_processors = []
 
         n_vocab = self.n_vocab()
         n_ctx = self.n_ctx()
@@ -332,10 +358,10 @@
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        for processor in logits_processors:
-            logits = processor(list(self.eval_tokens), logits)
-
-        self.eval_logits[-1] = logits
+
+        if logits_processor is not None:
+            logits = logits_processor(list(self.eval_tokens), logits)
+
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
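Because _sample feeds the processor list the full token history and replaces the logits row with its return value, processors compose left to right and can implement per-token biasing. A hypothetical OpenAI-style logit_bias shim on top of this hook — the function name and bias mapping are illustrative, not part of this commit:

from typing import Dict, List


def bias_processor(bias: Dict[int, float]):
    # Illustrative: add a fixed offset to chosen token ids at every step,
    # mirroring the OpenAI-style logit_bias parameter.
    def processor(input_ids: List[int], scores: List[float]) -> List[float]:
        for token_id, offset in bias.items():
            scores[token_id] += offset
        return scores
    return processor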
@@ -444,8 +470,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
-
+        logits_processor: Optional[LogitsProcessorList] = None,
     ):
         """Sample a token from the model.
 
@@ -478,8 +503,7 @@
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
-            logits_processors=logits_processors
-
+            logits_processor=logits_processor,
         )
 
     def generate(
@@ -496,7 +520,8 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -554,8 +579,12 @@
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
-                logits_processors=logits_processors
+                logits_processor=logits_processor,
             )
+            if stopping_criteria is not None and stopping_criteria(
+                list(self.eval_tokens), self.eval_logits[-1]
+            ):
+                return
             tokens_or_none = yield token
             tokens = [token]
             if tokens_or_none is not None:
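Since the criteria run right after each token is sampled, a caller can cut the stream short without touching the completion plumbing. A usage sketch assuming the illustrative helpers above; the model path and prompt are placeholders:

from llama_cpp import Llama, LogitsProcessorList, StoppingCriteriaList

llm = Llama(model_path="./models/7B/ggml-model.bin")
stream = llm.generate(
    llm.tokenize(b"Q: Name the planets in the solar system. A:"),
    top_k=40,
    top_p=0.95,
    temp=0.8,
    repeat_penalty=1.1,
    # Suppress EOS and let the stopping criterion end generation instead.
    logits_processor=LogitsProcessorList([suppress_token(llm.token_eos())]),
    stopping_criteria=StoppingCriteriaList([max_total_tokens(128)]),
)
for token in stream:
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="")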
@@ -651,14 +680,9 @@
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
-        stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
 
-        if stopping_criterias is None:
-            stopping_criterias = []
-
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
         completion_tokens: List[int] = []
@@ -720,7 +744,6 @@
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
-            logits_processors=logits_processors
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -728,14 +751,6 @@
                 break
 
             completion_tokens.append(token)
-            for stopping_crit in stopping_criterias:
-                if stopping_crit(completion_tokens, None):
-                    text = self.detokenize(completion_tokens)
-                    finish_reason = "stop"
-                    break
-
-            if finish_reason == "stop":
-                break
 
             all_text = self.detokenize(completion_tokens)
 
@@ -1035,8 +1050,6 @@
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
-        logits_processors: List[Callable[[List[int], List[float]], List[float]]] = None,
-        stopping_criterias: List[Callable[[List[int], List[float]], bool]] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
 
@@ -1079,9 +1092,6 @@
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
-            logits_processors=logits_processors,
-            stopping_criterias=stopping_criterias
-
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks