diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 9679b2e..66c76c9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -273,13 +273,12 @@ class Llama: self.params.low_vram = low_vram self.tensor_split = tensor_split - self._c_tensor_split = None + self._p_tensor_split = None if self.tensor_split is not None: - #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES - FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value - self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd - self.params.tensor_split = self._c_tensor_split + FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split) + self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd + self.params.tensor_split = self._p_tensor_split self.params.rope_freq_base = rope_freq_base self.params.rope_freq_scale = rope_freq_scale