Implemented logits processors and stopping criteria

This commit is contained in:
Maximilian-Winter 2023-05-24 21:55:44 +02:00
parent e5d596e0e9
commit 5bb780d455


@@ -316,6 +316,7 @@ class Llama:
         mirostat_tau: llama_cpp.c_float,
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
+        logits_processors=None
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
@@ -328,6 +329,10 @@ class Llama:
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
+        for processor in logits_processors:
+            logits = processor(list(self.eval_tokens), logits)
+        self.eval_logits[-1] = logits
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
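The hunk above implicitly fixes the processor interface: each entry in logits_processors is called as processor(list(self.eval_tokens), logits) and must return the adjusted logits. Below is a minimal sketch of such a callable, assuming the logits arrive as a plain Python list of floats (as self.eval_logits[-1] suggests at this point in the code); the factory name and the banned-token parameter are illustrative, not part of the library.

from typing import List


def ban_token_processor(banned_token_id: int):
    """Illustrative factory: build a processor that keeps one token id
    from ever being sampled by pushing its logit to negative infinity."""

    def processor(input_ids: List[int], logits: List[float]) -> List[float]:
        # input_ids is the full token history (list(self.eval_tokens));
        # logits holds the scores for the next token.
        adjusted = list(logits)  # copy so the caller's list is not mutated in place
        adjusted[banned_token_id] = -float("inf")
        return adjusted

    return processor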
@@ -436,6 +441,8 @@ class Llama:
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
         penalize_nl: bool = True,
+        logits_processors=None
     ):
         """Sample a token from the model.
@@ -468,6 +475,8 @@ class Llama:
             mirostat_tau=llama_cpp.c_float(mirostat_tau),
             mirostat_eta=llama_cpp.c_float(mirostat_eta),
             penalize_nl=penalize_nl,
+            logits_processors=logits_processors
         )
     def generate(
@@ -484,6 +493,7 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        logits_processors=None
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
@@ -541,6 +551,7 @@ class Llama:
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                logits_processors=logits_processors
             )
             tokens_or_none = yield token
             tokens = [token]
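For context, a hedged sketch of driving the updated generate method with the new keyword argument. The model path is a placeholder, the identity processor is purely illustrative, and the sampling parameters are passed explicitly since their defaults are not shown in this diff.

import itertools

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path
prompt_tokens = llm.tokenize(b"The planets of the solar system are")


def identity_processor(input_ids, logits):
    # Receives the token history and the raw logits; returns them unchanged.
    return logits


for token in itertools.islice(
    llm.generate(
        prompt_tokens,
        top_k=40,
        top_p=0.95,
        temp=0.8,
        repeat_penalty=1.1,
        logits_processors=[identity_processor],
    ),
    16,  # stream at most 16 tokens for the demo
):
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="")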
@@ -637,6 +648,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -700,6 +713,7 @@ class Llama:
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
+            logits_processors=logits_processors
         ):
             if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
@@ -707,6 +721,14 @@ class Llama:
                 break
             completion_tokens.append(token)
+            for stopping_crit in stopping_criterias:
+                if stopping_crit(completion_tokens, None):
+                    text = self.detokenize(completion_tokens)
+                    finish_reason = "stop"
+                    break
+            if finish_reason == "stop":
+                break
             all_text = self.detokenize(completion_tokens)
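The loop above calls each entry in stopping_criterias as stopping_crit(completion_tokens, None) and treats a truthy return as a stop signal; the second argument is passed as None in this commit. A minimal sketch of a compatible criterion follows; the factory name and the limit parameter are illustrative, not part of the library.

from typing import List, Optional


def max_new_tokens_criteria(limit: int):
    """Illustrative factory: stop once `limit` completion tokens exist."""

    def criteria(tokens: List[int], logits: Optional[List[float]]) -> bool:
        # tokens is the completion so far; logits is None in this commit.
        return len(tokens) >= limit

    return criteria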
@@ -1006,6 +1028,8 @@ class Llama:
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
+        logits_processors=None,
+        stopping_criterias=None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
         """Generate text from a prompt.
@@ -1048,6 +1072,9 @@ class Llama:
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
             model=model,
+            logits_processors=logits_processors,
+            stopping_criterias=stopping_criterias
         )
         if stream:
             chunks: Iterator[CompletionChunk] = completion_or_chunks
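Putting the new parameters together, a hedged end-to-end sketch against the updated create_completion signature. The model path is a placeholder, both callables are defined inline for illustration, and the sketch passes explicit lists for the new arguments since the hunks iterate over them directly rather than guarding the None defaults.

from llama_cpp import Llama

llm = Llama(model_path="./models/7B/ggml-model.bin")  # placeholder path


def suppress_token_processor(input_ids, logits):
    # Push one token's logit to -inf; token id 13 is a hypothetical example.
    adjusted = list(logits)
    adjusted[13] = -float("inf")
    return adjusted


def stop_after_32_tokens(tokens, logits):
    # Stop once 32 completion tokens have been generated; logits is None here.
    return len(tokens) >= 32


output = llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=64,
    logits_processors=[suppress_token_processor],
    stopping_criterias=[stop_after_32_tokens],
)
print(output["choices"][0]["text"])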