* kv overrides another attempt * add sentinel element, simplify array population * ensure sentinel element is zeroed
This commit is contained in:
parent
7e63928bc9
commit
fe5d6ea648
1 changed files with 8 additions and 11 deletions
|
@ -194,16 +194,16 @@ class Llama:
|
|||
self.model_params.use_mmap = use_mmap if lora_path is None else False
|
||||
self.model_params.use_mlock = use_mlock
|
||||
|
||||
# kv_overrides is the original python dict
|
||||
self.kv_overrides = kv_overrides
|
||||
if kv_overrides is not None:
|
||||
n_overrides = len(kv_overrides)
|
||||
self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1)
|
||||
self._kv_overrides_array_keys = []
|
||||
|
||||
for k, v in kv_overrides.items():
|
||||
key_buf = ctypes.create_string_buffer(k.encode("utf-8"))
|
||||
self._kv_overrides_array_keys.append(key_buf)
|
||||
self._kv_overrides_array[i].key = key_buf
|
||||
# _kv_overrides_array is a ctypes.Array of llama_model_kv_override Structs
|
||||
kvo_array_len = len(kv_overrides) + 1 # for sentinel element
|
||||
self._kv_overrides_array = (llama_cpp.llama_model_kv_override * kvo_array_len)()
|
||||
|
||||
for i, (k, v) in enumerate(kv_overrides.items()):
|
||||
self._kv_overrides_array[i].key = k.encode('utf-8');
|
||||
if isinstance(v, int):
|
||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT
|
||||
self._kv_overrides_array[i].value.int_value = v
|
||||
|
@ -216,10 +216,7 @@ class Llama:
|
|||
else:
|
||||
raise ValueError(f"Unknown value type for {k}: {v}")
|
||||
|
||||
self._kv_overrides_array_sentinel_key = b'\0'
|
||||
|
||||
# null array sentinel
|
||||
self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key
|
||||
self._kv_overrides_array[-1].key = b'\0' # ensure sentinel element is zeroed
|
||||
self.model_params.kv_overrides = self._kv_overrides_array
|
||||
|
||||
self.n_batch = min(n_ctx, n_batch) # ???
|
||||
|
|
Loading…
Add table
Reference in a new issue