diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 44363a8..7e17d36 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -291,6 +291,7 @@ class Llama: mirostat_mode: llama_cpp.c_int, mirostat_tau: llama_cpp.c_float, mirostat_eta: llama_cpp.c_float, + penalize_nl: bool = True, ): assert self.ctx is not None assert len(self.eval_logits) > 0 @@ -299,6 +300,7 @@ class Llama: top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k last_n_tokens_size = llama_cpp.c_int(n_ctx) if last_n_tokens_size.value < 0 else last_n_tokens_size logits = self.eval_logits[-1] + nl_logit = logits[llama_cpp.llama_token_nl().value] data = (llama_cpp.llama_token_data * n_vocab)( *[ llama_cpp.llama_token_data( @@ -331,6 +333,8 @@ class Llama: alpha_frequency=frequency_penalty, alpha_presence=presence_penalty, ) + if not penalize_nl: + candidates.data[llama_cpp.llama_token_nl().value].logit = nl_logit if temp.value == 0.0: return llama_cpp.llama_sample_token_greedy( ctx=self.ctx, @@ -413,6 +417,7 @@ class Llama: mirostat_mode: int = 0, mirostat_eta: float = 0.1, mirostat_tau: float = 5.0, + penalize_nl: bool = True, ): """Sample a token from the model. @@ -444,6 +449,7 @@ class Llama: mirostat_mode=llama_cpp.c_int(mirostat_mode), mirostat_tau=llama_cpp.c_float(mirostat_tau), mirostat_eta=llama_cpp.c_float(mirostat_eta), + penalize_nl=penalize_nl, ) def generate( @@ -1170,6 +1176,11 @@ class Llama: """Return the beginning-of-sequence token.""" return llama_cpp.llama_token_bos() + @staticmethod + def token_nl() -> llama_cpp.llama_token: + """Return the newline token.""" + return llama_cpp.llama_token_nl() + @staticmethod def logits_to_logprobs(logits: List[float]) -> List[float]: exps = [math.exp(float(x)) for x in logits]