Minor fix to tensor_split parameter

This commit is contained in:
Andrei Betlen 2023-07-14 16:40:53 -04:00
parent e6c67c8f7d
commit 25b3494e11

View file

@ -207,7 +207,6 @@ class Llama:
n_ctx: int = 512, n_ctx: int = 512,
n_parts: int = -1, n_parts: int = -1,
n_gpu_layers: int = 0, n_gpu_layers: int = 0,
tensor_split: list[float] = None,
seed: int = 1337, seed: int = 1337,
f16_kv: bool = True, f16_kv: bool = True,
logits_all: bool = False, logits_all: bool = False,
@ -221,6 +220,7 @@ class Llama:
lora_base: Optional[str] = None, lora_base: Optional[str] = None,
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
low_vram: bool = False, low_vram: bool = False,
tensor_split: Optional[List[float]] = None,
verbose: bool = True, verbose: bool = True,
): ):
"""Load a llama.cpp model from `model_path`. """Load a llama.cpp model from `model_path`.
@ -241,6 +241,7 @@ class Llama:
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
lora_path: Path to a LoRA file to apply to the model. lora_path: Path to a LoRA file to apply to the model.
tensor_split: List of floats to split the model across multiple GPUs. If None, the model is not split.
verbose: Print verbose output to stderr. verbose: Print verbose output to stderr.
Raises: Raises:
@ -249,12 +250,6 @@ class Llama:
Returns: Returns:
A Llama instance. A Llama instance.
""" """
if tensor_split is None:
tensor_split = [0.0] * llama_cpp.LLAMA_MAX_DEVICES.value
#Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
c_tensor_split = FloatArray(*tensor_split)
self.verbose = verbose self.verbose = verbose
self.model_path = model_path self.model_path = model_path
@ -262,7 +257,6 @@ class Llama:
self.params = llama_cpp.llama_context_default_params() self.params = llama_cpp.llama_context_default_params()
self.params.n_ctx = n_ctx self.params.n_ctx = n_ctx
self.params.n_gpu_layers = n_gpu_layers self.params.n_gpu_layers = n_gpu_layers
self.params.tensor_split = c_tensor_split
self.params.seed = seed self.params.seed = seed
self.params.f16_kv = f16_kv self.params.f16_kv = f16_kv
self.params.logits_all = logits_all self.params.logits_all = logits_all
@ -272,6 +266,15 @@ class Llama:
self.params.embedding = embedding self.params.embedding = embedding
self.params.low_vram = low_vram self.params.low_vram = low_vram
self.tensor_split = tensor_split
self._c_tensor_split = None
if self.tensor_split is not None:
#Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
self.params.tensor_split = self._c_tensor_split
self.last_n_tokens_size = last_n_tokens_size self.last_n_tokens_size = last_n_tokens_size
self.n_batch = min(n_ctx, n_batch) self.n_batch = min(n_ctx, n_batch)
@ -1499,7 +1502,6 @@ class Llama:
model_path=self.model_path, model_path=self.model_path,
n_ctx=self.params.n_ctx, n_ctx=self.params.n_ctx,
n_gpu_layers=self.params.n_gpu_layers, n_gpu_layers=self.params.n_gpu_layers,
tensor_split=self.params.tensor_split,
seed=self.params.seed, seed=self.params.seed,
f16_kv=self.params.f16_kv, f16_kv=self.params.f16_kv,
logits_all=self.params.logits_all, logits_all=self.params.logits_all,
@ -1513,6 +1515,7 @@ class Llama:
n_threads=self.n_threads, n_threads=self.n_threads,
lora_base=self.lora_base, lora_base=self.lora_base,
lora_path=self.lora_path, lora_path=self.lora_path,
tensor_split=self.tensor_split,
### DEPRECATED ### ### DEPRECATED ###
n_parts=self.n_parts, n_parts=self.n_parts,
### DEPRECATED ### ### DEPRECATED ###
@ -1524,7 +1527,6 @@ class Llama:
n_ctx=state["n_ctx"], n_ctx=state["n_ctx"],
n_parts=state["n_parts"], n_parts=state["n_parts"],
n_gpu_layers=state["n_gpu_layers"], n_gpu_layers=state["n_gpu_layers"],
tensor_split=state["tensor_split"],
seed=state["seed"], seed=state["seed"],
f16_kv=state["f16_kv"], f16_kv=state["f16_kv"],
logits_all=state["logits_all"], logits_all=state["logits_all"],
@ -1538,6 +1540,7 @@ class Llama:
last_n_tokens_size=state["last_n_tokens_size"], last_n_tokens_size=state["last_n_tokens_size"],
lora_base=state["lora_base"], lora_base=state["lora_base"],
lora_path=state["lora_path"], lora_path=state["lora_path"],
tensor_split=state["tensor_split"],
verbose=state["verbose"], verbose=state["verbose"],
) )