fix: fix string-valued kv_overrides. Closes #1487

This commit is contained in:
Andrei Betlen 2024-05-29 02:02:22 -04:00
parent 10b7c50cd2
commit df45a4b3fe
2 changed files with 11 additions and 6 deletions

View file

@ -6,6 +6,7 @@ import uuid
import time
import json
import ctypes
import typing
import fnmatch
import multiprocessing
@ -249,13 +250,13 @@ class Llama:
self._kv_overrides_array[i].key = k.encode("utf-8")
if isinstance(v, bool):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL
self._kv_overrides_array[i].value.bool_value = v
self._kv_overrides_array[i].value.val_bool = v
elif isinstance(v, int):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT
self._kv_overrides_array[i].value.int_value = v
self._kv_overrides_array[i].value.val_i64 = v
elif isinstance(v, float):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
self._kv_overrides_array[i].value.float_value = v
self._kv_overrides_array[i].value.val_f64 = v
elif isinstance(v, str): # type: ignore
v_bytes = v.encode("utf-8")
if len(v_bytes) > 128: # TODO: Make this a constant
@ -263,10 +264,12 @@ class Llama:
v_bytes = v_bytes.ljust(128, b"\0")
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
# copy the NUL-padded 128-byte value into the val_str buffer
address = typing.cast(int, ctypes.addressof(self._kv_overrides_array[i].value) + llama_cpp.llama_model_kv_override_value.val_str.offset)
buffer_start = ctypes.cast(address, ctypes.POINTER(ctypes.c_char))
ctypes.memmove(
self._kv_overrides_array[i].value.str_value,
buffer_start,
v_bytes,
min(len(v_bytes), 128),
128,
)
else:
raise ValueError(f"Unknown value type for {k}: {v}")

View file

@ -183,7 +183,7 @@ class LlamaProxy:
num_pred_tokens=settings.draft_model_num_pred_tokens
)
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None
if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list)
kv_overrides = {}
@ -197,6 +197,8 @@ class LlamaProxy:
kv_overrides[key] = int(value)
elif value_type == "float":
kv_overrides[key] = float(value)
elif value_type == "str":
kv_overrides[key] = value
else:
raise ValueError(f"Unknown value type {value_type}")