From cd102e9da1a0e6159e5489f2cab23c207f4916a5 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Sun, 21 May 2023 19:18:56 -0400
Subject: [PATCH] Cache shared library function calls for static tokens

---
 llama_cpp/llama.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 2d405b7..7a152fd 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -198,6 +198,8 @@ class Llama:
             sorted=sorted,
         )
         self._candidates = candidates
+        self._token_nl = Llama.token_nl()
+        self._token_eos = Llama.token_eos()
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -327,7 +329,7 @@ class Llama:
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[Llama.token_nl()]
+        nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
             candidates.data[i].id = llama_cpp.llama_token(i)
@@ -351,7 +353,7 @@ class Llama:
             alpha_presence=presence_penalty,
        )
         if not penalize_nl:
-            candidates.data[Llama.token_nl()].logit = llama_cpp.c_float(nl_logit)
+            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
@@ -688,7 +690,7 @@ class Llama:
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
-            if token == Llama.token_eos():
+            if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
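
Note: the motivation for this change is that Llama.token_nl() and Llama.token_eos() each cross the ctypes FFI boundary into the llama.cpp shared library on every call, even though their return values never change for the lifetime of the process. Fetching them once in the constructor and reusing the cached Python ints removes that per-token overhead from the hot sampling loop. Below is a minimal, self-contained sketch of the same pattern; it uses libc's getpid() purely as a stand-in for a constant-valued shared-library call (an assumption for illustration, not part of this patch):

    import ctypes
    import time

    # Load the already-linked C runtime (POSIX). getpid() stands in here for
    # llama_cpp's static-token calls: a cheap FFI call with a constant result.
    libc = ctypes.CDLL(None)
    libc.getpid.restype = ctypes.c_int

    N = 1_000_000

    # Uncached: one ctypes round trip per loop iteration.
    start = time.perf_counter()
    for _ in range(N):
        token = libc.getpid()
    uncached = time.perf_counter() - start

    # Cached: cross the FFI boundary once, then reuse a plain Python int,
    # mirroring self._token_eos / self._token_nl in the patch above.
    cached_token = libc.getpid()
    start = time.perf_counter()
    for _ in range(N):
        token = cached_token
    cached = time.perf_counter() - start

    print(f"uncached: {uncached:.3f}s  cached: {cached:.3f}s")

The cached loop skips ctypes argument marshalling and dispatch entirely, which is the per-sampled-token saving this patch applies.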