Implement penalize_nl

This commit is contained in:
Andrei Betlen 2023-05-17 01:53:26 -04:00
parent f11e2a781c
commit d28b753ed2

View file

@ -291,6 +291,7 @@ class Llama:
mirostat_mode: llama_cpp.c_int, mirostat_mode: llama_cpp.c_int,
mirostat_tau: llama_cpp.c_float, mirostat_tau: llama_cpp.c_float,
mirostat_eta: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float,
penalize_nl: bool = True,
): ):
assert self.ctx is not None assert self.ctx is not None
assert len(self.eval_logits) > 0 assert len(self.eval_logits) > 0
@ -299,6 +300,7 @@ class Llama:
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size
logits = self.eval_logits[-1] logits = self.eval_logits[-1]
nl_logit = logits[llama_cpp.llama_token_nl().value]
data = (llama_cpp.llama_token_data * n_vocab)( data = (llama_cpp.llama_token_data * n_vocab)(
*[ *[
llama_cpp.llama_token_data( llama_cpp.llama_token_data(
@ -331,6 +333,8 @@ class Llama:
alpha_frequency=frequency_penalty, alpha_frequency=frequency_penalty,
alpha_presence=presence_penalty, alpha_presence=presence_penalty,
) )
if not penalize_nl:
candidates.data[llama_cpp.llama_token_nl().value].logit = nl_logit
if temp.value == 0.0: if temp.value == 0.0:
return llama_cpp.llama_sample_token_greedy( return llama_cpp.llama_sample_token_greedy(
ctx=self.ctx, ctx=self.ctx,
@ -413,6 +417,7 @@ class Llama:
mirostat_mode: int = 0, mirostat_mode: int = 0,
mirostat_eta: float = 0.1, mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0, mirostat_tau: float = 5.0,
penalize_nl: bool = True,
): ):
"""Sample a token from the model. """Sample a token from the model.
@ -444,6 +449,7 @@ class Llama:
mirostat_mode=llama_cpp.c_int(mirostat_mode), mirostat_mode=llama_cpp.c_int(mirostat_mode),
mirostat_tau=llama_cpp.c_float(mirostat_tau), mirostat_tau=llama_cpp.c_float(mirostat_tau),
mirostat_eta=llama_cpp.c_float(mirostat_eta), mirostat_eta=llama_cpp.c_float(mirostat_eta),
penalize_nl=penalize_nl,
) )
def generate( def generate(
@ -1170,6 +1176,11 @@ class Llama:
"""Return the beginning-of-sequence token.""" """Return the beginning-of-sequence token."""
return llama_cpp.llama_token_bos() return llama_cpp.llama_token_bos()
@staticmethod
def token_nl() -> llama_cpp.llama_token:
"""Return the newline token."""
return llama_cpp.llama_token_nl()
@staticmethod @staticmethod
def logits_to_logprobs(logits: List[float]) -> List[float]: def logits_to_logprobs(logits: List[float]) -> List[float]:
exps = [math.exp(float(x)) for x in logits] exps = [math.exp(float(x)) for x in logits]