diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 251d064..7ca7af0 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -216,7 +216,6 @@ class Llama:
         embedding: bool = False,
         n_threads: Optional[int] = None,
         n_batch: int = 512,
-        n_gqa: Optional[int] = None,  # must be 8 for llama2 70b
         last_n_tokens_size: int = 64,
         lora_base: Optional[str] = None,
         lora_path: Optional[str] = None,
@@ -224,6 +223,8 @@ class Llama:
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
+        n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
+        rms_eps_norm: Optional[float] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -261,8 +262,6 @@ class Llama:
 
         self.params = llama_cpp.llama_context_default_params()
         self.params.n_ctx = n_ctx
-        if n_gqa is not None:
-            self.params.n_gqa = n_gqa
         self.params.n_gpu_layers = n_gpu_layers
         self.params.seed = seed
         self.params.f16_kv = f16_kv
@@ -285,6 +284,12 @@ class Llama:
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale
 
+        if n_gqa is not None:
+            self.params.n_gqa = n_gqa
+
+        if rms_eps_norm is not None:
+            self.params.rms_eps_norm = rms_eps_norm
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
 
@@ -1526,6 +1531,10 @@ class Llama:
             lora_base=self.lora_base,
             lora_path=self.lora_path,
             tensor_split=self.tensor_split,
+            ### TEMPORARY ###
+            n_gqa=self.params.n_gqa,
+            rms_eps_norm=self.params.rms_eps_norm,
+            ### TEMPORARY ###
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1535,7 +1544,6 @@ class Llama:
         self.__init__(
             model_path=state["model_path"],
             n_ctx=state["n_ctx"],
-            n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
@@ -1551,7 +1559,14 @@ class Llama:
             lora_base=state["lora_base"],
             lora_path=state["lora_path"],
             tensor_split=state["tensor_split"],
+            ### TEMPORARY ###
+            n_gqa=state["n_gqa"],
+            rms_eps_norm=state["rms_eps_norm"],
+            ### TEMPORARY ###
             verbose=state["verbose"],
+            ### DEPRECATED ###
+            n_parts=state["n_parts"],
+            ### DEPRECATED ###
         )
 
     def save_state(self) -> LlamaState:
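
For reviewers, here is a minimal usage sketch of the two new (TEMPORARY) keyword arguments this diff adds to `Llama.__init__`. The model path and the epsilon value are illustrative assumptions; only the parameter names `n_gqa` and `rms_eps_norm` come from the change itself:

```python
from llama_cpp import Llama

# Sketch only: the path and epsilon value below are assumptions for
# illustration, not values taken from this diff.
llm = Llama(
    model_path="./models/llama-2-70b.ggmlv3.q4_0.bin",  # hypothetical path
    n_gqa=8,            # (TEMPORARY) grouped-query attention count; must be 8 for llama2 70b
    rms_eps_norm=1e-5,  # (TEMPORARY) RMSNorm epsilon override; example value
)
```

Because the `__getstate__`/`__setstate__` hunks now carry both values through the state dict, instances constructed this way should also survive a pickle round trip.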