diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 481842b..96aac66 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -73,7 +73,7 @@ class Llama: vocab_only: bool = False, use_mmap: bool = True, use_mlock: bool = False, - kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None, + kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None, # Context Params seed: int = llama_cpp.LLAMA_DEFAULT_SEED, n_ctx: int = 512, @@ -254,6 +254,13 @@ class Llama: elif isinstance(v, float): self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT self._kv_overrides_array[i].value.float_value = v + elif isinstance(v, str): # type: ignore + v_bytes = v.encode("utf-8") + if len(v_bytes) > 128: # TODO: Make this a constant + raise ValueError(f"Value for {k} is too long: {v}") + v_bytes = v_bytes.ljust(128, b"\0") + self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR + self._kv_overrides_array[i].value.str_value[:128] = v_bytes else: raise ValueError(f"Unknown value type for {k}: {v}")