diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index c5869ed..849e775 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -207,7 +207,6 @@ class Llama:
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
-        tensor_split: list[float] = None,
         seed: int = 1337,
         f16_kv: bool = True,
         logits_all: bool = False,
@@ -221,6 +220,7 @@ class Llama:
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
         low_vram: bool = False,
+        tensor_split: Optional[List[float]] = None,
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -241,6 +241,7 @@ class Llama:
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
+            tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
             verbose: Print verbose output to stderr.
 
         Raises:
@@ -249,12 +250,6 @@ class Llama:
         Returns:
             A Llama instance.
         """
-        if tensor_split is None:
-            tensor_split = [0.0] * llama_cpp.LLAMA_MAX_DEVICES.value
-
-        #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
-        FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
-        c_tensor_split = FloatArray(*tensor_split)
         self.verbose = verbose
 
         self.model_path = model_path
@@ -262,7 +257,6 @@ class Llama:
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
         self.params.n_gpu_layers = n_gpu_layers
-        self.params.tensor_split = c_tensor_split
         self.params.seed = seed
         self.params.f16_kv = f16_kv
         self.params.logits_all = logits_all
@@ -272,6 +266,15 @@ class Llama:
         self.params.embedding = embedding
         self.params.low_vram = low_vram
 
+        self.tensor_split = tensor_split
+        self._c_tensor_split = None
+
+        if self.tensor_split is not None:
+            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
+            FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
+            self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
+            self.params.tensor_split = self._c_tensor_split
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
@@ -1499,7 +1502,6 @@ class Llama:
             model_path=self.model_path,
             n_ctx=self.params.n_ctx,
             n_gpu_layers=self.params.n_gpu_layers,
-            tensor_split=self.params.tensor_split,
             seed=self.params.seed,
             f16_kv=self.params.f16_kv,
             logits_all=self.params.logits_all,
@@ -1513,6 +1515,7 @@ class Llama:
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
+            tensor_split=self.tensor_split,
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1524,7 +1527,6 @@ class Llama:
             n_ctx=state["n_ctx"],
             n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
-            tensor_split=state["tensor_split"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
             logits_all=state["logits_all"],
@@ -1538,6 +1540,7 @@ class Llama:
             last_n_tokens_size=state["last_n_tokens_size"],
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
+            tensor_split=state["tensor_split"],
            verbose=state["verbose"],
         )
 
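
For reference, a minimal usage sketch of the new parameter, assuming a two-GPU setup; the model path, layer count, and split ratios below are illustrative placeholders and not part of this patch:

    from llama_cpp import Llama

    # Hypothetical example: ask llama.cpp to place roughly 60% of the tensors
    # on device 0 and 40% on device 1. Path and ratios are placeholders.
    llm = Llama(
        model_path="./models/7B/ggml-model.bin",
        n_gpu_layers=40,
        tensor_split=[0.6, 0.4],  # one entry per device; the ctypes array is zero-padded up to LLAMA_MAX_DEVICES
    )

Keeping the original Python list on `self.tensor_split` (and only building the ctypes array when the argument is actually given) lets `__getstate__`/`__setstate__` round-trip the value as plain Python data, while `self._c_tensor_split` holds a reference so the array assigned to `self.params.tensor_split` is not garbage collected.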