Add validation for tensor_split size exceeding LLAMA_MAX_DEVICES (#820)
* Add validation for tensor_split size exceeding LLAMA_MAX_DEVICES
* reword
This commit is contained in:
parent
f30aa20126
commit
b50166500e
1 changed file with 2 additions and 0 deletions
|
@@ -308,6 +308,8 @@ class Llama:
|
||||||
self.tensor_split = tensor_split
|
self.tensor_split = tensor_split
|
||||||
self._p_tensor_split = None
|
self._p_tensor_split = None
|
||||||
if self.tensor_split is not None:
|
if self.tensor_split is not None:
|
||||||
|
if len(self.tensor_split) > llama_cpp.LLAMA_MAX_DEVICES:
|
||||||
|
raise ValueError(f"Attempt to split tensors that exceed maximum supported devices. Current LLAMA_MAX_DEVICES={llama_cpp.LLAMA_MAX_DEVICES}")
|
||||||
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
|
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
|
||||||
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
|
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES
|
||||||
self._c_tensor_split = FloatArray(
|
self._c_tensor_split = FloatArray(
|
||||||
|
|
Loading…
Reference in a new issue