fix: GGUF metadata KV overrides, re #1011 (#1116)

* kv overrides: another attempt

* add sentinel element, simplify array population

* ensure sentinel element is zeroed
This commit is contained in:
Phil H 2024-01-24 03:00:38 +00:00 committed by GitHub
parent 7e63928bc9
commit fe5d6ea648
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -194,16 +194,16 @@ class Llama:
self.model_params.use_mmap = use_mmap if lora_path is None else False self.model_params.use_mmap = use_mmap if lora_path is None else False
self.model_params.use_mlock = use_mlock self.model_params.use_mlock = use_mlock
# kv_overrides is the original python dict
self.kv_overrides = kv_overrides self.kv_overrides = kv_overrides
if kv_overrides is not None: if kv_overrides is not None:
n_overrides = len(kv_overrides)
self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1)
self._kv_overrides_array_keys = []
for k, v in kv_overrides.items(): # _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
key_buf = ctypes.create_string_buffer(k.encode("utf-8")) kvo_array_len = len(kv_overrides) + 1 # for sentinel element
self._kv_overrides_array_keys.append(key_buf) self._kv_overrides_array = (llama_cpp.llama_model_kv_override * kvo_array_len)()
self._kv_overrides_array[i].key = key_buf
for i, (k, v) in enumerate(kv_overrides.items()):
self._kv_overrides_array[i].key = k.encode('utf-8');
if isinstance(v, int): if isinstance(v, int):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
self._kv_overrides_array[i].value.int_value = v self._kv_overrides_array[i].value.int_value = v
@ -216,10 +216,7 @@ class Llama:
else: else:
raise ValueError(f"Unknown value type for {k}: {v}") raise ValueError(f"Unknown value type for {k}: {v}")
self._kv_overrides_array_sentinel_key = b'\0' self._kv_overrides_array[-1].key = b'\0' # ensure sentinel element is zeroed
# null array sentinel
self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key
self.model_params.kv_overrides = self._kv_overrides_array self.model_params.kv_overrides = self._kv_overrides_array
self.n_batch = min(n_ctx, n_batch) # ??? self.n_batch = min(n_ctx, n_batch) # ???