Add rms_eps_norm
This commit is contained in:
parent
e4431a6ade
commit
8cd64d4ac3
1 changed files with 19 additions and 4 deletions
|
@ -216,7 +216,6 @@ class Llama:
|
||||||
embedding: bool = False,
|
embedding: bool = False,
|
||||||
n_threads: Optional[int] = None,
|
n_threads: Optional[int] = None,
|
||||||
n_batch: int = 512,
|
n_batch: int = 512,
|
||||||
n_gqa: Optional[int] = None, # must be 8 for llama2 70b
|
|
||||||
last_n_tokens_size: int = 64,
|
last_n_tokens_size: int = 64,
|
||||||
lora_base: Optional[str] = None,
|
lora_base: Optional[str] = None,
|
||||||
lora_path: Optional[str] = None,
|
lora_path: Optional[str] = None,
|
||||||
|
@ -224,6 +223,8 @@ class Llama:
|
||||||
tensor_split: Optional[List[float]] = None,
|
tensor_split: Optional[List[float]] = None,
|
||||||
rope_freq_base: float = 10000.0,
|
rope_freq_base: float = 10000.0,
|
||||||
rope_freq_scale: float = 1.0,
|
rope_freq_scale: float = 1.0,
|
||||||
|
n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b
|
||||||
|
rms_eps_norm: Optional[float] = None, # (TEMPORARY)
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
):
|
):
|
||||||
"""Load a llama.cpp model from `model_path`.
|
"""Load a llama.cpp model from `model_path`.
|
||||||
|
@ -261,8 +262,6 @@ class Llama:
|
||||||
|
|
||||||
self.params = llama_cpp.llama_context_default_params()
|
self.params = llama_cpp.llama_context_default_params()
|
||||||
self.params.n_ctx = n_ctx
|
self.params.n_ctx = n_ctx
|
||||||
if n_gqa is not None:
|
|
||||||
self.params.n_gqa = n_gqa
|
|
||||||
self.params.n_gpu_layers = n_gpu_layers
|
self.params.n_gpu_layers = n_gpu_layers
|
||||||
self.params.seed = seed
|
self.params.seed = seed
|
||||||
self.params.f16_kv = f16_kv
|
self.params.f16_kv = f16_kv
|
||||||
|
@ -285,6 +284,12 @@ class Llama:
|
||||||
self.params.rope_freq_base = rope_freq_base
|
self.params.rope_freq_base = rope_freq_base
|
||||||
self.params.rope_freq_scale = rope_freq_scale
|
self.params.rope_freq_scale = rope_freq_scale
|
||||||
|
|
||||||
|
if n_gqa is not None:
|
||||||
|
self.params.n_gqa = n_gqa
|
||||||
|
|
||||||
|
if rms_eps_norm is not None:
|
||||||
|
self.params.rms_eps_norm = rms_eps_norm
|
||||||
|
|
||||||
self.last_n_tokens_size = last_n_tokens_size
|
self.last_n_tokens_size = last_n_tokens_size
|
||||||
self.n_batch = min(n_ctx, n_batch)
|
self.n_batch = min(n_ctx, n_batch)
|
||||||
|
|
||||||
|
@ -1526,6 +1531,10 @@ class Llama:
|
||||||
lora_base=self.lora_base,
|
lora_base=self.lora_base,
|
||||||
lora_path=self.lora_path,
|
lora_path=self.lora_path,
|
||||||
tensor_split=self.tensor_split,
|
tensor_split=self.tensor_split,
|
||||||
|
### TEMPORARY ###
|
||||||
|
n_gqa=self.params.n_gqa,
|
||||||
|
rms_eps_norm=self.params.rms_eps_norm,
|
||||||
|
### TEMPORARY ###
|
||||||
### DEPRECATED ###
|
### DEPRECATED ###
|
||||||
n_parts=self.n_parts,
|
n_parts=self.n_parts,
|
||||||
### DEPRECATED ###
|
### DEPRECATED ###
|
||||||
|
@ -1535,7 +1544,6 @@ class Llama:
|
||||||
self.__init__(
|
self.__init__(
|
||||||
model_path=state["model_path"],
|
model_path=state["model_path"],
|
||||||
n_ctx=state["n_ctx"],
|
n_ctx=state["n_ctx"],
|
||||||
n_parts=state["n_parts"],
|
|
||||||
n_gpu_layers=state["n_gpu_layers"],
|
n_gpu_layers=state["n_gpu_layers"],
|
||||||
seed=state["seed"],
|
seed=state["seed"],
|
||||||
f16_kv=state["f16_kv"],
|
f16_kv=state["f16_kv"],
|
||||||
|
@ -1551,7 +1559,14 @@ class Llama:
|
||||||
lora_base=state["lora_base"],
|
lora_base=state["lora_base"],
|
||||||
lora_path=state["lora_path"],
|
lora_path=state["lora_path"],
|
||||||
tensor_split=state["tensor_split"],
|
tensor_split=state["tensor_split"],
|
||||||
|
n_gqa=state["n_gqa"],
|
||||||
|
### TEMPORARY ###
|
||||||
|
rms_eps_norm=state["rms_eps_norm"],
|
||||||
verbose=state["verbose"],
|
verbose=state["verbose"],
|
||||||
|
### TEMPORARY ###
|
||||||
|
### DEPRECATED ###
|
||||||
|
n_parts=state["n_parts"],
|
||||||
|
### DEPRECATED ###
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_state(self) -> LlamaState:
|
def save_state(self) -> LlamaState:
|
||||||
|
|
Loading…
Reference in a new issue