From 03e2947b03beb3954e128783b401ced6143c275d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sun, 21 May 2023 18:36:34 -0400 Subject: [PATCH] Fix unnecessary memory allocation while sampling --- llama_cpp/llama.py | 54 ++++++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 332a882..7943084 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -176,6 +176,28 @@ class Llama: if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) + + + n_vocab = self.n_vocab() + n_ctx = self.n_ctx() + data = (llama_cpp.llama_token_data * n_vocab)( + *[ + llama_cpp.llama_token_data( + id=llama_cpp.llama_token(i), + logit=llama_cpp.c_float(0.0), + p=llama_cpp.c_float(0.0), + ) + for i in range(n_vocab) + ] + ) + size = llama_cpp.c_size_t(n_vocab) + sorted = False + candidates = llama_cpp.llama_token_data_array( + data=data, + size=size, + sorted=sorted, + ) + self._candidates = candidates def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: """Tokenize a string. @@ -296,8 +318,8 @@ class Llama: ): assert self.ctx is not None assert len(self.eval_logits) > 0 - n_vocab = int(llama_cpp.llama_n_vocab(self.ctx)) - n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) + n_vocab = self.n_vocab() + n_ctx = self.n_ctx() top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k last_n_tokens_size = ( llama_cpp.c_int(n_ctx) @@ -305,24 +327,14 @@ class Llama: else last_n_tokens_size ) logits = self.eval_logits[-1] - nl_logit = logits[int(Llama.token_nl())] - data = (llama_cpp.llama_token_data * n_vocab)( - *[ - llama_cpp.llama_token_data( - id=llama_cpp.llama_token(i), - logit=logits[i], - p=llama_cpp.c_float(0.0), - ) - for i in range(n_vocab) - ] - ) - size = llama_cpp.c_size_t(n_vocab) - sorted = False - candidates = llama_cpp.llama_token_data_array( - data=data, - size=size, - sorted=sorted, - ) + nl_logit = logits[Llama.token_nl()] + candidates = self._candidates + for i, logit in enumerate(logits): + candidates.data[i].id = llama_cpp.llama_token(i) + candidates.data[i].logit = llama_cpp.c_float(logit) + candidates.data[i].p = llama_cpp.c_float(0.0) + candidates.sorted = llama_cpp.c_bool(False) + candidates.size = llama_cpp.c_size_t(n_vocab) llama_cpp.llama_sample_repetition_penalty( ctx=self.ctx, last_tokens_data=last_n_tokens_data, @@ -339,7 +351,7 @@ class Llama: alpha_presence=presence_penalty, ) if not penalize_nl: - candidates.data[int(Llama.token_nl())].logit = nl_logit + candidates.data[Llama.token_nl()].logit = nl_logit if temp.value == 0.0: return llama_cpp.llama_sample_token_greedy( ctx=self.ctx,