diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 3d15800..74739cb 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -205,15 +205,15 @@ class Llama: for i, (k, v) in enumerate(kv_overrides.items()): self._kv_overrides_array[i].key = k.encode("utf-8") - if isinstance(v, int): + if isinstance(v, bool): + self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL + self._kv_overrides_array[i].value.bool_value = v + elif isinstance(v, int): self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_INT self._kv_overrides_array[i].value.int_value = v elif isinstance(v, float): self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_FLOAT self._kv_overrides_array[i].value.float_value = v - elif isinstance(v, bool): - self._kv_overrides_array[i].tag = llama_cpp.LLAMA_KV_OVERRIDE_BOOL - self._kv_overrides_array[i].value.bool_value = v else: raise ValueError(f"Unknown value type for {k}: {v}")