feat: Add support for str type kv_overrides
This commit is contained in:
parent
c9b85bf098
commit
a411612b38
1 changed files with 8 additions and 1 deletions
|
@ -73,7 +73,7 @@ class Llama:
|
|||
vocab_only: bool = False,
|
||||
use_mmap: bool = True,
|
||||
use_mlock: bool = False,
|
||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None,
|
||||
kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
|
||||
# Context Params
|
||||
seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
|
||||
n_ctx: int = 512,
|
||||
|
@ -254,6 +254,13 @@ class Llama:
|
|||
elif isinstance(v, float):
|
||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT
|
||||
self._kv_overrides_array[i].value.float_value = v
|
||||
elif isinstance(v, str): # type: ignore
|
||||
v_bytes = v.encode("utf-8")
|
||||
if len(v_bytes) > 128: # TODO: Make this a constant
|
||||
raise ValueError(f"Value for {k} is too long: {v}")
|
||||
v_bytes = v_bytes.ljust(128, b"\0")
|
||||
self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR
|
||||
self._kv_overrides_array[i].value.str_value[:128] = v_bytes
|
||||
else:
|
||||
raise ValueError(f"Unknown value type for {k}: {v}")
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue