feat: Add support for str type kv_overrides
This commit is contained in:
parent
c9b85bf098
commit
a411612b38
1 changed files with 8 additions and 1 deletions
|
@ -73,7 +73,7 @@ class Llama:
|
||||||
vocab_only: bool = False,
|
vocab_only: bool = False,
|
||||||
use_mmap: bool = True,
|
use_mmap: bool = True,
|
||||||
use_mlock: bool = False,
|
use_mlock: bool = False,
|
||||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
|
kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
|
||||||
# Context Params
|
# Context Params
|
||||||
seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
|
seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
|
||||||
n_ctx: int = 512,
|
n_ctx: int = 512,
|
||||||
|
@ -254,6 +254,13 @@ class Llama:
|
||||||
elif isinstance(v, float):
|
elif isinstance(v, float):
|
||||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
|
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
|
||||||
self._kv_overrides_array[i].value.float_value = v
|
self._kv_overrides_array[i].value.float_value = v
|
||||||
|
elif isinstance(v, str): # type: ignore
|
||||||
|
v_bytes = v.encode("utf-8")
|
||||||
|
if len(v_bytes) > 128: # TODO: Make this a constant
|
||||||
|
raise ValueError(f"Value for {k} is too long: {v}")
|
||||||
|
v_bytes = v_bytes.ljust(128, b"\0")
|
||||||
|
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
|
||||||
|
self._kv_overrides_array[i].value.str_value[:128] = v_bytes
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown value type for {k}: {v}")
|
raise ValueError(f"Unknown value type for {k}: {v}")
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue