commit f96de6d920
Andrei Betlen 2024-04-03 00:55:21 -04:00


@@ -730,12 +730,14 @@ class _LlamaSamplingContext:
         if len(self.prev) > 0:
             nl_token = ctx_main.model.token_nl()
             nl_logit = logits_array[nl_token]
-            if self.params.penalty_last_n > 0:
+            last_tokens = self.prev[-self.params.penalty_last_n:]
+            last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
+            if last_tokens_size > 0:
+                last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens)
                 ctx_main.sample_repetition_penalties(
                     token_data_array,
-                    # TODO: Only create this once
-                    (llama_cpp.llama_token * len(self.prev))(*self.prev),
-                    self.params.penalty_last_n,
+                    last_tokens_p,
+                    last_tokens_size,
                     self.params.penalty_repeat,
                     self.params.penalty_freq,
                     self.params.penalty_present,
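
In effect, the new code penalizes only the most recent penalty_last_n tokens and passes that window's actual length to sample_repetition_penalties, rather than handing over the full self.prev history with penalty_last_n as the claimed size. Below is a minimal, self-contained sketch of the slicing and ctypes-array construction; llama_token is assumed to be an int32 ctypes alias (standing in for llama_cpp.llama_token), and the token values are invented for illustration.

    import ctypes

    llama_token = ctypes.c_int32  # assumed stand-in for llama_cpp.llama_token

    prev = [1, 5, 9, 9, 2]   # hypothetical token history (self.prev)
    penalty_last_n = 3       # hypothetical self.params.penalty_last_n

    # Keep only the most recent penalty_last_n tokens and record how many we actually have.
    last_tokens = prev[-penalty_last_n:]
    last_tokens_size = min(len(last_tokens), penalty_last_n)

    if last_tokens_size > 0:
        # Build the C token array and pass its true size, mirroring the
        # arguments given to ctx_main.sample_repetition_penalties() above.
        last_tokens_p = (llama_token * len(last_tokens))(*last_tokens)
        print(list(last_tokens_p), last_tokens_size)  # -> [9, 9, 2] 3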