Implement penalize_nl
This commit is contained in:
parent
f11e2a781c
commit
d28b753ed2
1 changed files with 11 additions and 0 deletions
|
@ -291,6 +291,7 @@ class Llama:
|
|||
mirostat_mode: llama_cpp.c_int,
|
||||
mirostat_tau: llama_cpp.c_float,
|
||||
mirostat_eta: llama_cpp.c_float,
|
||||
penalize_nl: bool = True,
|
||||
):
|
||||
assert self.ctx is not None
|
||||
assert len(self.eval_logits) > 0
|
||||
|
@ -299,6 +300,7 @@ class Llama:
|
|||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||
last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size
|
||||
logits = self.eval_logits[-1]
|
||||
nl_logit = logits[llama_cpp.llama_token_nl().value]
|
||||
data = (llama_cpp.llama_token_data * n_vocab)(
|
||||
*[
|
||||
llama_cpp.llama_token_data(
|
||||
|
@ -331,6 +333,8 @@ class Llama:
|
|||
alpha_frequency=frequency_penalty,
|
||||
alpha_presence=presence_penalty,
|
||||
)
|
||||
if not penalize_nl:
|
||||
candidates.data[llama_cpp.llama_token_nl().value].logit = nl_logit
|
||||
if temp.value == 0.0:
|
||||
return llama_cpp.llama_sample_token_greedy(
|
||||
ctx=self.ctx,
|
||||
|
@ -413,6 +417,7 @@ class Llama:
|
|||
mirostat_mode: int = 0,
|
||||
mirostat_eta: float = 0.1,
|
||||
mirostat_tau: float = 5.0,
|
||||
penalize_nl: bool = True,
|
||||
):
|
||||
"""Sample a token from the model.
|
||||
|
||||
|
@ -444,6 +449,7 @@ class Llama:
|
|||
mirostat_mode=llama_cpp.c_int(mirostat_mode),
|
||||
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
||||
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
||||
penalize_nl=penalize_nl,
|
||||
)
|
||||
|
||||
def generate(
|
||||
|
@ -1170,6 +1176,11 @@ class Llama:
|
|||
"""Return the beginning-of-sequence token."""
|
||||
return llama_cpp.llama_token_bos()
|
||||
|
||||
@staticmethod
|
||||
def token_nl() -> llama_cpp.llama_token:
|
||||
"""Return the newline token."""
|
||||
return llama_cpp.llama_token_nl()
|
||||
|
||||
@staticmethod
|
||||
def logits_to_logprobs(logits: List[float]) -> List[float]:
|
||||
exps = [math.exp(float(x)) for x in logits]
|
||||
|
|
Loading…
Reference in a new issue