fix: fix string value kv_overrides. Closes #1487
This commit is contained in:
parent
10b7c50cd2
commit
df45a4b3fe
2 changed files with 11 additions and 6 deletions
|
@ -6,6 +6,7 @@ import uuid
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import typing
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
|
@ -249,13 +250,13 @@ class Llama:
|
||||||
self._kv_overrides_array[i].key = k.encode("utf-8")
|
self._kv_overrides_array[i].key = k.encode("utf-8")
|
||||||
if isinstance(v, bool):
|
if isinstance(v, bool):
|
||||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
|
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
|
||||||
self._kv_overrides_array[i].value.bool_value = v
|
self._kv_overrides_array[i].value.val_bool = v
|
||||||
elif isinstance(v, int):
|
elif isinstance(v, int):
|
||||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
|
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
|
||||||
self._kv_overrides_array[i].value.int_value = v
|
self._kv_overrides_array[i].value.val_i64 = v
|
||||||
elif isinstance(v, float):
|
elif isinstance(v, float):
|
||||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
|
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
|
||||||
self._kv_overrides_array[i].value.float_value = v
|
self._kv_overrides_array[i].value.val_f64 = v
|
||||||
elif isinstance(v, str): # type: ignore
|
elif isinstance(v, str): # type: ignore
|
||||||
v_bytes = v.encode("utf-8")
|
v_bytes = v.encode("utf-8")
|
||||||
if len(v_bytes) > 128: # TODO: Make this a constant
|
if len(v_bytes) > 128: # TODO: Make this a constant
|
||||||
|
@ -263,10 +264,12 @@ class Llama:
|
||||||
v_bytes = v_bytes.ljust(128, b"\0")
|
v_bytes = v_bytes.ljust(128, b"\0")
|
||||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
|
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
|
||||||
# copy min(v_bytes, 128) to str_value
|
# copy min(v_bytes, 128) to str_value
|
||||||
|
address = typing.cast(int, ctypes.addressof(self._kv_overrides_array[i].value) + llama_cpp.llama_model_kv_override_value.val_str.offset)
|
||||||
|
buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char))
|
||||||
ctypes.memmove(
|
ctypes.memmove(
|
||||||
self._kv_overrides_array[i].value.str_value,
|
buffer_start,
|
||||||
v_bytes,
|
v_bytes,
|
||||||
min(len(v_bytes), 128),
|
128,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown value type for {k}: {v}")
|
raise ValueError(f"Unknown value type for {k}: {v}")
|
||||||
|
|
|
@ -183,7 +183,7 @@ class LlamaProxy:
|
||||||
num_pred_tokens=settings.draft_model_num_pred_tokens
|
num_pred_tokens=settings.draft_model_num_pred_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None
|
||||||
if settings.kv_overrides is not None:
|
if settings.kv_overrides is not None:
|
||||||
assert isinstance(settings.kv_overrides, list)
|
assert isinstance(settings.kv_overrides, list)
|
||||||
kv_overrides = {}
|
kv_overrides = {}
|
||||||
|
@ -197,6 +197,8 @@ class LlamaProxy:
|
||||||
kv_overrides[key] = int(value)
|
kv_overrides[key] = int(value)
|
||||||
elif value_type == "float":
|
elif value_type == "float":
|
||||||
kv_overrides[key] = float(value)
|
kv_overrides[key] = float(value)
|
||||||
|
elif value_type == "str":
|
||||||
|
kv_overrides[key] = value
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown value type {value_type}")
|
raise ValueError(f"Unknown value type {value_type}")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue