Change tensor_split from array to pointer
This commit is contained in:
parent
c7c700b0d4
commit
426dbfe3f4
1 changed files with 4 additions and 5 deletions
|
@ -273,13 +273,12 @@ class Llama:
|
||||||
self.params.low_vram = low_vram
|
self.params.low_vram = low_vram
|
||||||
|
|
||||||
self.tensor_split = tensor_split
|
self.tensor_split = tensor_split
|
||||||
self._c_tensor_split = None
|
self._p_tensor_split = None
|
||||||
|
|
||||||
if self.tensor_split is not None:
|
if self.tensor_split is not None:
|
||||||
#Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
|
FloatArray = (ctypes.c_float * len(self.tensor_split))(*self.tensor_split)
|
||||||
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
|
self._p_tensor_split = ctypes.POINTER(ctypes.c_float)(FloatArray) # keep a reference to the array so it is not gc'd
|
||||||
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
|
self.params.tensor_split = self._p_tensor_split
|
||||||
self.params.tensor_split = self._c_tensor_split
|
|
||||||
|
|
||||||
self.params.rope_freq_base = rope_freq_base
|
self.params.rope_freq_base = rope_freq_base
|
||||||
self.params.rope_freq_scale = rope_freq_scale
|
self.params.rope_freq_scale = rope_freq_scale
|
||||||
|
|
Loading…
Reference in a new issue