Implemented logit processors and stop criteria's
This commit is contained in:
parent
e5d596e0e9
commit
5bb780d455
1 changed files with 27 additions and 0 deletions
|
@ -316,6 +316,7 @@ class Llama:
|
|||
mirostat_tau: llama_cpp.c_float,
|
||||
mirostat_eta: llama_cpp.c_float,
|
||||
penalize_nl: bool = True,
|
||||
logits_processors=None
|
||||
):
|
||||
assert self.ctx is not None
|
||||
assert len(self.eval_logits) > 0
|
||||
|
@ -328,6 +329,10 @@ class Llama:
|
|||
else last_n_tokens_size
|
||||
)
|
||||
logits = self.eval_logits[-1]
|
||||
for processor in logits_processors:
|
||||
logits = processor(list(self.eval_tokens), logits)
|
||||
|
||||
self.eval_logits[-1] = logits
|
||||
nl_logit = logits[self._token_nl]
|
||||
candidates = self._candidates
|
||||
for i, logit in enumerate(logits):
|
||||
|
@ -436,6 +441,8 @@ class Llama:
|
|||
mirostat_eta: float = 0.1,
|
||||
mirostat_tau: float = 5.0,
|
||||
penalize_nl: bool = True,
|
||||
logits_processors=None
|
||||
|
||||
):
|
||||
"""Sample a token from the model.
|
||||
|
||||
|
@ -468,6 +475,8 @@ class Llama:
|
|||
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
||||
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
||||
penalize_nl=penalize_nl,
|
||||
logits_processors=logits_processors
|
||||
|
||||
)
|
||||
|
||||
def generate(
|
||||
|
@ -484,6 +493,7 @@ class Llama:
|
|||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
logits_processors=None
|
||||
) -> Generator[int, Optional[Sequence[int]], None]:
|
||||
"""Create a generator of tokens from a prompt.
|
||||
|
||||
|
@ -541,6 +551,7 @@ class Llama:
|
|||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
logits_processors=logits_processors
|
||||
)
|
||||
tokens_or_none = yield token
|
||||
tokens = [token]
|
||||
|
@ -637,6 +648,8 @@ class Llama:
|
|||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
logits_processors=None,
|
||||
stopping_criterias=None
|
||||
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
||||
assert self.ctx is not None
|
||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||
|
@ -700,6 +713,7 @@ class Llama:
|
|||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
logits_processors=logits_processors
|
||||
):
|
||||
if token == self._token_eos:
|
||||
text = self.detokenize(completion_tokens)
|
||||
|
@ -707,6 +721,14 @@ class Llama:
|
|||
break
|
||||
|
||||
completion_tokens.append(token)
|
||||
for stopping_crit in stopping_criterias:
|
||||
if stopping_crit(completion_tokens, None):
|
||||
text = self.detokenize(completion_tokens)
|
||||
finish_reason = "stop"
|
||||
break
|
||||
|
||||
if finish_reason == "stop":
|
||||
break
|
||||
|
||||
all_text = self.detokenize(completion_tokens)
|
||||
|
||||
|
@ -1006,6 +1028,8 @@ class Llama:
|
|||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
logits_processors=None,
|
||||
stopping_criterias=None
|
||||
) -> Union[Completion, Iterator[CompletionChunk]]:
|
||||
"""Generate text from a prompt.
|
||||
|
||||
|
@ -1048,6 +1072,9 @@ class Llama:
|
|||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processors=logits_processors,
|
||||
stopping_criterias=stopping_criterias
|
||||
|
||||
)
|
||||
if stream:
|
||||
chunks: Iterator[CompletionChunk] = completion_or_chunks
|
||||
|
|
Loading…
Reference in a new issue