Implement GGUF metadata KV overrides (#1011)
* Implement GGUF metadata overrides * whitespace fix * Fix kv overrides. * Fix pointer and pickle * Match llama.cpp kv_overrides cli argument --------- Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
parent
7eff42c239
commit
76aafa6149
3 changed files with 55 additions and 1 deletions
|
@ -735,6 +735,7 @@ class Llama:
|
|||
vocab_only: bool = False,
|
||||
use_mmap: bool = True,
|
||||
use_mlock: bool = False,
|
||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
|
||||
# Context Params
|
||||
seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
|
||||
n_ctx: int = 512,
|
||||
|
@ -803,6 +804,7 @@ class Llama:
|
|||
vocab_only: Only load the vocabulary no weights.
|
||||
use_mmap: Use mmap if possible.
|
||||
use_mlock: Force the system to keep the model in RAM.
|
||||
kv_overrides: Key-value overrides for the model.
|
||||
seed: RNG seed, -1 for random
|
||||
n_ctx: Text context, 0 = from model
|
||||
n_batch: Prompt processing maximum batch size
|
||||
|
@ -866,6 +868,34 @@ class Llama:
|
|||
self.model_params.use_mmap = use_mmap if lora_path is None else False
|
||||
self.model_params.use_mlock = use_mlock
|
||||
|
||||
self.kv_overrides = kv_overrides
|
||||
if kv_overrides is not None:
|
||||
n_overrides = len(kv_overrides)
|
||||
self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1)
|
||||
self._kv_overrides_array_keys = []
|
||||
|
||||
for k, v in kv_overrides.items():
|
||||
key_buf = ctypes.create_string_buffer(k.encode("utf-8"))
|
||||
self._kv_overrides_array_keys.append(key_buf)
|
||||
self._kv_overrides_array[i].key = key_buf
|
||||
if isinstance(v, int):
|
||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
|
||||
self._kv_overrides_array[i].value.int_value = v
|
||||
elif isinstance(v, float):
|
||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT
|
||||
self._kv_overrides_array[i].value.float_value = v
|
||||
elif isinstance(v, bool):
|
||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL
|
||||
self._kv_overrides_array[i].value.bool_value = v
|
||||
else:
|
||||
raise ValueError(f"Unknown value type for {k}: {v}")
|
||||
|
||||
self._kv_overrides_array_sentinel_key = b'\0'
|
||||
|
||||
# null array sentinel
|
||||
self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key
|
||||
self.model_params.kv_overrides = self._kv_overrides_array
|
||||
|
||||
self.n_batch = min(n_ctx, n_batch) # ???
|
||||
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
|
||||
self.n_threads_batch = n_threads_batch or max(
|
||||
|
@ -2148,6 +2178,7 @@ class Llama:
|
|||
vocab_only=self.model_params.vocab_only,
|
||||
use_mmap=self.model_params.use_mmap,
|
||||
use_mlock=self.model_params.use_mlock,
|
||||
kv_overrides=self.kv_overrides,
|
||||
# Context Params
|
||||
seed=self.context_params.seed,
|
||||
n_ctx=self.context_params.n_ctx,
|
||||
|
@ -2190,6 +2221,7 @@ class Llama:
|
|||
vocab_only=state["vocab_only"],
|
||||
use_mmap=state["use_mmap"],
|
||||
use_mlock=state["use_mlock"],
|
||||
kv_overrides=state["kv_overrides"],
|
||||
# Context Params
|
||||
seed=state["seed"],
|
||||
n_ctx=state["n_ctx"],
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Union, List
|
||||
from typing import Dict, Optional, Union, List
|
||||
|
||||
import llama_cpp
|
||||
|
||||
|
@ -71,6 +71,23 @@ class LlamaProxy:
|
|||
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
|
||||
clip_model_path=settings.clip_model_path, verbose=settings.verbose
|
||||
)
|
||||
|
||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
||||
if settings.kv_overrides is not None:
|
||||
assert isinstance(settings.kv_overrides, list)
|
||||
kv_overrides = {}
|
||||
for kv in settings.kv_overrides:
|
||||
key, value = kv.split("=")
|
||||
if ":" in value:
|
||||
value_type, value = value.split(":")
|
||||
if value_type == "bool":
|
||||
kv_overrides[key] = value.lower() in ["true", "1"]
|
||||
elif value_type == "int":
|
||||
kv_overrides[key] = int(value)
|
||||
elif value_type == "float":
|
||||
kv_overrides[key] = float(value)
|
||||
else:
|
||||
raise ValueError(f"Unknown value type {value_type}")
|
||||
|
||||
_model = llama_cpp.Llama(
|
||||
model_path=settings.model,
|
||||
|
@ -81,6 +98,7 @@ class LlamaProxy:
|
|||
vocab_only=settings.vocab_only,
|
||||
use_mmap=settings.use_mmap,
|
||||
use_mlock=settings.use_mlock,
|
||||
kv_overrides=kv_overrides,
|
||||
# Context Params
|
||||
seed=settings.seed,
|
||||
n_ctx=settings.n_ctx,
|
||||
|
|
|
@ -48,6 +48,10 @@ class ModelSettings(BaseSettings):
|
|||
default=llama_cpp.llama_mlock_supported(),
|
||||
description="Use mlock.",
|
||||
)
|
||||
kv_overrides: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.",
|
||||
)
|
||||
# Context Params
|
||||
seed: int = Field(
|
||||
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
|
||||
|
|
Loading…
Reference in a new issue