From 11dd2bf3829896b00d7af1121d19e60c03385987 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 24 Jul 2023 14:09:24 -0400 Subject: [PATCH] Add temporary rms_norm_eps parameter --- llama_cpp/llama.py | 14 +++++++------- llama_cpp/server/app.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 7ca7af0..9679b2e 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -224,7 +224,7 @@ class Llama: rope_freq_base: float = 10000.0, rope_freq_scale: float = 1.0, n_gqa: Optional[int] = None, # (TEMPORARY) must be 8 for llama2 70b - rms_eps_norm: Optional[float] = None, # (TEMPORARY) + rms_norm_eps: Optional[float] = None, # (TEMPORARY) verbose: bool = True, ): """Load a llama.cpp model from `model_path`. @@ -287,8 +287,8 @@ class Llama: if n_gqa is not None: self.params.n_gqa = n_gqa - if rms_eps_norm is not None: - self.params.rms_eps_norm = rms_eps_norm + if rms_norm_eps is not None: + self.params.rms_norm_eps = rms_norm_eps self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) @@ -1533,7 +1533,7 @@ class Llama: tensor_split=self.tensor_split, ### TEMPORARY ### n_gqa=self.params.n_gqa, - rms_eps_norm=self.params.rms_eps_norm, + rms_norm_eps=self.params.rms_norm_eps, ### TEMPORARY ### ### DEPRECATED ### n_parts=self.n_parts, @@ -1559,11 +1559,11 @@ class Llama: lora_base=state["lora_base"], lora_path=state["lora_path"], tensor_split=state["tensor_split"], - n_gqa=state["n_gqa"], - ### TEMPORARY ### - rms_eps_norm=state["rms_eps_norm"], verbose=state["verbose"], ### TEMPORARY ### + n_gqa=state["n_gqa"], + rms_norm_eps=state["rms_norm_eps"], + ### TEMPORARY ### ### DEPRECATED ### n_parts=state["n_parts"], ### DEPRECATED ### diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index ba68ba8..4afcfd5 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -95,6 +95,14 @@ class Settings(BaseSettings): default=True, description="Whether to interrupt requests when a new request is received.", ) + n_gqa: Optional[int] = Field( + default=None, + description="TEMPORARY: Set to 8 for Llama2 70B", + ) + rms_norm_eps: Optional[float] = Field( + default=None, + description="TEMPORARY", + ) class ErrorResponse(TypedDict): @@ -320,6 +328,8 @@ def create_app(settings: Optional[Settings] = None): last_n_tokens_size=settings.last_n_tokens_size, vocab_only=settings.vocab_only, verbose=settings.verbose, + n_gqa=settings.n_gqa, + rms_norm_eps=settings.rms_norm_eps, ) if settings.cache: if settings.cache_type == "disk":