From fab064ded91209f7f1e3fe5ff5b247db891c446a Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 May 2023 17:56:21 -0400 Subject: [PATCH] Remove unnecessary ffi calls --- llama_cpp/llama.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 916fe07..43fa9c7 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -177,19 +177,19 @@ class Llama: if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) - n_vocab = self.n_vocab() - n_ctx = self.n_ctx() - data = (llama_cpp.llama_token_data * n_vocab)( + self._n_vocab = self.n_vocab() + self._n_ctx = self.n_ctx() + data = (llama_cpp.llama_token_data * self._n_vocab)( *[ llama_cpp.llama_token_data( id=llama_cpp.llama_token(i), logit=llama_cpp.c_float(0.0), p=llama_cpp.c_float(0.0), ) - for i in range(n_vocab) + for i in range(self._n_vocab) ] ) - size = llama_cpp.c_size_t(n_vocab) + size = llama_cpp.c_size_t(self._n_vocab) sorted = False candidates = llama_cpp.llama_token_data_array( data=data, @@ -213,8 +213,8 @@ class Llama: A list of tokens. """ assert self.ctx is not None - n_ctx = llama_cpp.llama_n_ctx(self.ctx) - tokens = (llama_cpp.llama_token * int(n_ctx))() + n_ctx = self._n_ctx + tokens = (llama_cpp.llama_token * n_ctx)() n_tokens = llama_cpp.llama_tokenize( self.ctx, text, @@ -222,9 +222,9 @@ class Llama: llama_cpp.c_int(n_ctx), llama_cpp.c_bool(add_bos), ) - if int(n_tokens) < 0: + if n_tokens < 0: n_tokens = abs(n_tokens) - tokens = (llama_cpp.llama_token * int(n_tokens))() + tokens = (llama_cpp.llama_token * n_tokens)() n_tokens = llama_cpp.llama_tokenize( self.ctx, text, @@ -275,7 +275,7 @@ class Llama: tokens: The list of tokens to evaluate. """ assert self.ctx is not None - n_ctx = int(llama_cpp.llama_n_ctx(self.ctx)) + n_ctx = self._n_ctx for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] n_past = min(n_ctx - len(batch), len(self.eval_tokens)) @@ -287,18 +287,16 @@ class Llama: n_past=llama_cpp.c_int(n_past), n_threads=llama_cpp.c_int(self.n_threads), ) - if int(return_code) != 0: + if return_code != 0: raise RuntimeError(f"llama_eval returned {return_code}") # Save tokens self.eval_tokens.extend(batch) # Save logits rows = n_tokens if self.params.logits_all else 1 - n_vocab = llama_cpp.llama_n_vocab(self.ctx) - cols = int(n_vocab) + n_vocab = self._n_vocab + cols = n_vocab logits_view = llama_cpp.llama_get_logits(self.ctx) - logits: List[List[float]] = [ - [logits_view[i * cols + j] for j in range(cols)] for i in range(rows) - ] + logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] self.eval_logits.extend(logits) def _sample( @@ -319,8 +317,8 @@ class Llama: ): assert self.ctx is not None assert len(self.eval_logits) > 0 - n_vocab = self.n_vocab() - n_ctx = self.n_ctx() + n_vocab = self._n_vocab + n_ctx = self._n_ctx top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k last_n_tokens_size = ( llama_cpp.c_int(n_ctx) @@ -654,9 +652,9 @@ class Llama: if self.verbose: llama_cpp.llama_reset_timings(self.ctx) - if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)): + if len(prompt_tokens) + max_tokens > self._n_ctx: raise ValueError( - f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}" + f"Requested tokens exceed context window of {self._n_ctx}" ) if stop != []: