From 426dbfe3f4518114a9a8d8ceb80146c89e56aee3 Mon Sep 17 00:00:00 2001
From: Shouyi Wang
Date: Tue, 25 Jul 2023 18:29:59 +1000
Subject: [PATCH] Change tensor_split from array to pointer

---
 llama_cpp/llama.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 9679b2e..66c76c9 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -273,13 +273,12 @@ class Llama:
         self.params.low_vram = low_vram
 
         self.tensor_split = tensor_split
-        self._c_tensor_split = None
+        self._p_tensor_split = None
 
         if self.tensor_split is not None:
-            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-            self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
-            self.params.tensor_split = self._c_tensor_split
+            FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
+            self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._p_tensor_split
 
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale