Cache shared library function calls for static tokens

Andrei Betlen 2023-05-21 19:18:56 -04:00
parent b895511cca
commit cd102e9da1

@@ -198,6 +198,8 @@ class Llama:
             sorted=sorted,
         )
         self._candidates = candidates
+        self._token_nl = Llama.token_nl()
+        self._token_eos = Llama.token_eos()
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -327,7 +329,7 @@ class Llama:
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[Llama.token_nl()]
+        nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
             candidates.data[i].id = llama_cpp.llama_token(i)
@@ -351,7 +353,7 @@ class Llama:
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[Llama.token_nl()].logit = llama_cpp.c_float(nl_logit)
+            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
@@ -688,7 +690,7 @@ class Llama:
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
-            if token == Llama.token_eos():
+            if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
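
Context on why this helps (a sketch, not part of the commit): Llama.token_nl() and Llama.token_eos() dispatch through ctypes into the compiled llama.cpp shared library, and _sample runs once per generated token, so ids that never change for a loaded model were being re-fetched across the FFI boundary on every sampling step. Below is a minimal, self-contained illustration of the same caching pattern, assuming a POSIX system where ctypes.CDLL(None) exposes libc; libc's abs stands in for the static token accessor, and Sampler, static_token, and the timing harness are hypothetical names invented for this sketch.

import ctypes
import time

# Stand-in for a shared-library accessor like llama_cpp.llama_token_nl():
# the value is constant, but every call pays ctypes dispatch overhead.
libc = ctypes.CDLL(None)  # POSIX: load symbols from the current process

def static_token() -> int:
    return libc.abs(-13)  # always returns 13, yet an FFI round trip each time

class Sampler:
    def __init__(self) -> None:
        # Fetch once at construction, as the commit does for
        # self._token_nl and self._token_eos.
        self._token = static_token()

    def sample_uncached(self) -> int:
        return static_token()  # FFI call inside the hot loop

    def sample_cached(self) -> int:
        return self._token  # plain attribute read

s = Sampler()
for fn in (s.sample_uncached, s.sample_cached):
    t0 = time.perf_counter()
    for _ in range(100_000):
        fn()
    print(f"{fn.__name__}: {time.perf_counter() - t0:.4f}s")

Per-call FFI overhead is tiny, but _sample sits on the per-token hot path, so trading the repeated library call for an attribute read adds up over long generations.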