diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7c819b0..6443b6d 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -735,6 +735,7 @@ class Llama:
         vocab_only: bool = False,
         use_mmap: bool = True,
         use_mlock: bool = False,
+        kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
         # Context Params
         seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
         n_ctx: int = 512,
@@ -803,6 +804,7 @@ class Llama:
             vocab_only: Only load the vocabulary no weights.
             use_mmap: Use mmap if possible.
             use_mlock: Force the system to keep the model in RAM.
+            kv_overrides: Key-value overrides for the model.
             seed: RNG seed, -1 for random
             n_ctx: Text context, 0 = from model
             n_batch: Prompt processing maximum batch size
@@ -866,6 +868,36 @@ class Llama:
         self.model_params.use_mmap = use_mmap if lora_path is None else False
         self.model_params.use_mlock = use_mlock
 
+        self.kv_overrides = kv_overrides
+        if kv_overrides is not None:
+            n_overrides = len(kv_overrides)
+            # Allocate an array *instance* (note the trailing call), with one
+            # extra zero-initialized slot for the sentinel element.
+            self._kv_overrides_array = (
+                llama_cpp.llama_model_kv_override * (n_overrides + 1)
+            )()
+
+            for i, (k, v) in enumerate(kv_overrides.items()):
+                # key is a fixed-size char array in the C struct, so ctypes
+                # copies the encoded bytes in place; no buffer must be kept.
+                self._kv_overrides_array[i].key = k.encode("utf-8")
+                # Check bool before int: bool is a subclass of int in Python.
+                if isinstance(v, bool):
+                    self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL
+                    self._kv_overrides_array[i].value.bool_value = v
+                elif isinstance(v, int):
+                    self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
+                    self._kv_overrides_array[i].value.int_value = v
+                elif isinstance(v, float):
+                    self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT
+                    self._kv_overrides_array[i].value.float_value = v
+                else:
+                    raise ValueError(f"Unknown value type for {k}: {v}")
+
+            # Zero-length key acts as the null array sentinel for llama.cpp.
+            self._kv_overrides_array[n_overrides].key = b"\0"
+            self.model_params.kv_overrides = self._kv_overrides_array
+
         self.n_batch = min(n_ctx, n_batch)  # ???
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or max(
@@ -2148,6 +2180,7 @@ class Llama:
             vocab_only=self.model_params.vocab_only,
             use_mmap=self.model_params.use_mmap,
             use_mlock=self.model_params.use_mlock,
+            kv_overrides=self.kv_overrides,
             # Context Params
             seed=self.context_params.seed,
             n_ctx=self.context_params.n_ctx,
@@ -2190,6 +2223,7 @@ class Llama:
             vocab_only=state["vocab_only"],
             use_mmap=state["use_mmap"],
             use_mlock=state["use_mlock"],
+            kv_overrides=state["kv_overrides"],
             # Context Params
             seed=state["seed"],
             n_ctx=state["n_ctx"],
diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py
index b9373b7..f9be323 100644
--- a/llama_cpp/server/model.py
+++ b/llama_cpp/server/model.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Optional, Union, List
+from typing import Dict, Optional, Union, List
 
 import llama_cpp
 
@@ -71,6 +71,24 @@ class LlamaProxy:
            chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
                clip_model_path=settings.clip_model_path, verbose=settings.verbose
            )
+
+        kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
+        if settings.kv_overrides is not None:
+            assert isinstance(settings.kv_overrides, list)
+            kv_overrides = {}
+            for kv in settings.kv_overrides:
+                # Split on the first "=" only, so values may contain "=".
+                key, value = kv.split("=", 1)
+                if ":" in value:
+                    value_type, value = value.split(":", 1)
+                    if value_type == "bool":
+                        kv_overrides[key] = value.lower() in ["true", "1"]
+                    elif value_type == "int":
+                        kv_overrides[key] = int(value)
+                    elif value_type == "float":
+                        kv_overrides[key] = float(value)
+                    else:
+                        raise ValueError(f"Unknown value type {value_type}")
 
         _model = llama_cpp.Llama(
             model_path=settings.model,
@@ -81,6 +98,7 @@ class LlamaProxy:
            vocab_only=settings.vocab_only,
            use_mmap=settings.use_mmap,
            use_mlock=settings.use_mlock,
+           kv_overrides=kv_overrides,
            # Context Params
            seed=settings.seed,
            n_ctx=settings.n_ctx,
diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index 346b463..3195d1d 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -48,6 +48,10 @@ class ModelSettings(BaseSettings):
         default=llama_cpp.llama_mlock_supported(),
         description="Use mlock.",
     )
+    kv_overrides: Optional[List[str]] = Field(
+        default=None,
+        description="List of model kv overrides in the format key=type:value, where type is one of (bool, int, float). For bool, the strings (true, 1) parse as true, case-insensitively; anything else is false.",
+    )
     # Context Params
     seed: int = Field(
         default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
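
From the caller's side, the new parameter is just a plain dict. A minimal usage sketch of the API this patch adds; the model path and the metadata key are illustrative, not part of the change:

```python
from llama_cpp import Llama

# kv_overrides is the parameter added by this patch; the key shown is a
# standard GGUF metadata key, used here purely as an example.
llm = Llama(
    model_path="./models/example.gguf",  # hypothetical local model path
    kv_overrides={"tokenizer.ggml.add_bos_token": False},
)
```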
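The llama.py hunk relies on a sentinel-terminated ctypes array: llama.cpp walks `kv_overrides` until it hits an element whose key is empty, so the Python side allocates one extra zero-initialized slot. Below is a standalone sketch of that pattern under the assumption that the struct roughly mirrors llama.cpp's `llama_model_kv_override` (a fixed `char key[128]`, an int tag, and a union of int/float/bool values); the struct and field names here are illustrative, not imported from the bindings:

```python
import ctypes

# Illustrative tag constants, standing in for the binding's LLAMA_KV_OVERRIDE_* values.
KV_OVERRIDE_INT, KV_OVERRIDE_FLOAT, KV_OVERRIDE_BOOL = 0, 1, 2

class _OverrideValue(ctypes.Union):
    _fields_ = [
        ("int_value", ctypes.c_int64),
        ("float_value", ctypes.c_double),
        ("bool_value", ctypes.c_bool),
    ]

class KvOverride(ctypes.Structure):
    _fields_ = [
        ("key", ctypes.c_char * 128),  # fixed-size array: bytes are copied in
        ("tag", ctypes.c_int),
        ("value", _OverrideValue),
    ]

def build_overrides(overrides: dict):
    # One extra slot: ctypes zero-initializes it, so its empty key acts as
    # the sentinel telling the C side where the array ends.
    arr = (KvOverride * (len(overrides) + 1))()
    for i, (k, v) in enumerate(overrides.items()):
        arr[i].key = k.encode("utf-8")
        if isinstance(v, bool):  # bool first: bool subclasses int in Python
            arr[i].tag = KV_OVERRIDE_BOOL
            arr[i].value.bool_value = v
        elif isinstance(v, int):
            arr[i].tag = KV_OVERRIDE_INT
            arr[i].value.int_value = v
        elif isinstance(v, float):
            arr[i].tag = KV_OVERRIDE_FLOAT
            arr[i].value.float_value = v
        else:
            raise ValueError(f"Unknown value type for {k}: {v}")
    return arr

arr = build_overrides({"llama.expert_used_count": 2, "my.flag": True})
assert arr[2].key == b""  # sentinel slot stays zeroed
```

A sentinel avoids passing a separate count through the C API, at the cost of requiring the caller to keep the array alive (hence the `self._kv_overrides_array` attribute, which pins it for the lifetime of the `Llama` object).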
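The server side accepts overrides as `key=type:value` strings (e.g. from `ModelSettings.kv_overrides`) and converts them to the dict the `Llama` constructor expects. A self-contained sketch of that conversion, mirroring the `server/model.py` hunk; `parse_kv_overrides` is a hypothetical helper name, and unlike the patch (which silently skips entries without a type tag) this version requires one:

```python
from typing import Dict, List, Union

def parse_kv_overrides(items: List[str]) -> Dict[str, Union[bool, int, float]]:
    overrides: Dict[str, Union[bool, int, float]] = {}
    for kv in items:
        key, value = kv.split("=", 1)            # key may not contain "="
        value_type, value = value.split(":", 1)  # type tag is mandatory here
        if value_type == "bool":
            overrides[key] = value.lower() in ["true", "1"]
        elif value_type == "int":
            overrides[key] = int(value)
        elif value_type == "float":
            overrides[key] = float(value)
        else:
            raise ValueError(f"Unknown value type {value_type}")
    return overrides

# e.g. settings.kv_overrides = ["tokenizer.ggml.add_bos_token=bool:false", "n=int:42"]
print(parse_kv_overrides(["tokenizer.ggml.add_bos_token=bool:false", "n=int:42"]))
# {'tokenizer.ggml.add_bos_token': False, 'n': 42}
```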