Implement penalize_nl
This commit is contained in:
parent
f11e2a781c
commit
d28b753ed2
1 changed files with 11 additions and 0 deletions
|
@ -291,6 +291,7 @@ class Llama:
|
||||||
mirostat_mode: llama_cpp.c_int,
|
mirostat_mode: llama_cpp.c_int,
|
||||||
mirostat_tau: llama_cpp.c_float,
|
mirostat_tau: llama_cpp.c_float,
|
||||||
mirostat_eta: llama_cpp.c_float,
|
mirostat_eta: llama_cpp.c_float,
|
||||||
|
penalize_nl: bool = True,
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
assert len(self.eval_logits) > 0
|
assert len(self.eval_logits) > 0
|
||||||
|
@ -299,6 +300,7 @@ class Llama:
|
||||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||||
last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size
|
last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size
|
||||||
logits = self.eval_logits[-1]
|
logits = self.eval_logits[-1]
|
||||||
|
nl_logit = logits[llama_cpp.llama_token_nl().value]
|
||||||
data = (llama_cpp.llama_token_data * n_vocab)(
|
data = (llama_cpp.llama_token_data * n_vocab)(
|
||||||
*[
|
*[
|
||||||
llama_cpp.llama_token_data(
|
llama_cpp.llama_token_data(
|
||||||
|
@ -331,6 +333,8 @@ class Llama:
|
||||||
alpha_frequency=frequency_penalty,
|
alpha_frequency=frequency_penalty,
|
||||||
alpha_presence=presence_penalty,
|
alpha_presence=presence_penalty,
|
||||||
)
|
)
|
||||||
|
if not penalize_nl:
|
||||||
|
candidates.data[llama_cpp.llama_token_nl().value].logit = nl_logit
|
||||||
if temp.value == 0.0:
|
if temp.value == 0.0:
|
||||||
return llama_cpp.llama_sample_token_greedy(
|
return llama_cpp.llama_sample_token_greedy(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
|
@ -413,6 +417,7 @@ class Llama:
|
||||||
mirostat_mode: int = 0,
|
mirostat_mode: int = 0,
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
mirostat_tau: float = 5.0,
|
mirostat_tau: float = 5.0,
|
||||||
|
penalize_nl: bool = True,
|
||||||
):
|
):
|
||||||
"""Sample a token from the model.
|
"""Sample a token from the model.
|
||||||
|
|
||||||
|
@ -444,6 +449,7 @@ class Llama:
|
||||||
mirostat_mode=llama_cpp.c_int(mirostat_mode),
|
mirostat_mode=llama_cpp.c_int(mirostat_mode),
|
||||||
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
mirostat_tau=llama_cpp.c_float(mirostat_tau),
|
||||||
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
mirostat_eta=llama_cpp.c_float(mirostat_eta),
|
||||||
|
penalize_nl=penalize_nl,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
|
@ -1170,6 +1176,11 @@ class Llama:
|
||||||
"""Return the beginning-of-sequence token."""
|
"""Return the beginning-of-sequence token."""
|
||||||
return llama_cpp.llama_token_bos()
|
return llama_cpp.llama_token_bos()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def token_nl() -> llama_cpp.llama_token:
|
||||||
|
"""Return the newline token."""
|
||||||
|
return llama_cpp.llama_token_nl()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def logits_to_logprobs(logits: List[float]) -> List[float]:
|
def logits_to_logprobs(logits: List[float]) -> List[float]:
|
||||||
exps = [math.exp(float(x)) for x in logits]
|
exps = [math.exp(float(x)) for x in logits]
|
||||||
|
|
Loading…
Reference in a new issue