Implement penalize_nl

This commit is contained in:
Andrei Betlen 2023-05-17 01:53:26 -04:00
parent f11e2a781c
commit d28b753ed2

View file

@ -291,6 +291,7 @@ class Llama:
mirostat_mode: llama_cpp.c_int,
mirostat_tau: llama_cpp.c_float,
mirostat_eta: llama_cpp.c_float,
penalize_nl: bool = True,
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
@ -299,6 +300,7 @@ class Llama:
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size
logits = self.eval_logits[-1]
nl_logit = logits[llama_cpp.llama_token_nl().value]
data = (llama_cpp.llama_token_data * n_vocab)(
*[
llama_cpp.llama_token_data(
@ -331,6 +333,8 @@ class Llama:
alpha_frequency=frequency_penalty,
alpha_presence=presence_penalty,
)
if not penalize_nl:
candidates.data[llama_cpp.llama_token_nl().value].logit = nl_logit
if temp.value == 0.0:
return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx,
@ -413,6 +417,7 @@ class Llama:
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
penalize_nl: bool = True,
):
"""Sample a token from the model.
@ -444,6 +449,7 @@ class Llama:
mirostat_mode=llama_cpp.c_int(mirostat_mode),
mirostat_tau=llama_cpp.c_float(mirostat_tau),
mirostat_eta=llama_cpp.c_float(mirostat_eta),
penalize_nl=penalize_nl,
)
def generate(
@ -1170,6 +1176,11 @@ class Llama:
"""Return the beginning-of-sequence token."""
return llama_cpp.llama_token_bos()
@staticmethod
def token_nl() -> llama_cpp.llama_token:
"""Return the newline token."""
return llama_cpp.llama_token_nl()
@staticmethod
def logits_to_logprobs(logits: List[float]) -> List[float]:
exps = [math.exp(float(x)) for x in logits]