From 5cae1040e35326dc514552bb9423c3718a37bf44 Mon Sep 17 00:00:00 2001 From: Linghan Zhong Date: Fri, 24 May 2024 00:49:44 -0500 Subject: [PATCH] feat: Improve Llama.eval performance by avoiding list conversion (#1476) Co-authored-by: Andrei --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 043fb2a..6dad650 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -562,12 +562,12 @@ class Llama: if self.context_params.logits_all: rows = n_tokens cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] + logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, )) self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits else: rows = 1 cols = self._n_vocab - logits = self._ctx.get_logits()[: rows * cols] + logits = np.ctypeslib.as_array(self._ctx.get_logits(), shape=(rows * cols, )) self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits # Update n_tokens self.n_tokens += n_tokens