feat: Add support for str type kv_overrides

This commit is contained in:
Andrei Betlen 2024-04-27 23:42:19 -04:00
parent c9b85bf098
commit a411612b38

View file

@ -73,7 +73,7 @@ class Llama:
vocab_only: bool = False,
use_mmap: bool = True,
use_mlock: bool = False,
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
# Context Params
seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
n_ctx: int = 512,
@ -254,6 +254,13 @@ class Llama:
elif isinstance(v, float):
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
self._kv_overrides_array[i].value.float_value = v
elif isinstance(v, str): # type: ignore
v_bytes = v.encode("utf-8")
if len(v_bytes) > 128: # TODO: Make this a constant
raise ValueError(f"Value for {k} is too long: {v}")
v_bytes = v_bytes.ljust(128, b"\0")
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
self._kv_overrides_array[i].value.str_value[:128] = v_bytes
else:
raise ValueError(f"Unknown value type for {k}: {v}")