From 22d77eefd2edaf0148f53374d0cac74d0e25d06e Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Tue, 30 Apr 2024 09:29:16 -0400
Subject: [PATCH] feat: Add option to enable `flash_attn` to Llama params and
 ModelSettings

---
 llama_cpp/llama.py           | 4 ++++
 llama_cpp/server/settings.py | 3 +++
 2 files changed, 7 insertions(+)

diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 96aac66..172f4c6 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -92,6 +92,7 @@ class Llama:
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
+        flash_attn: bool = False,
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
@@ -168,6 +169,7 @@ class Llama:
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
+            flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
@@ -310,6 +312,7 @@ class Llama:
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
+        self.context_params.flash_attn = flash_attn
         # KV cache quantization
         if type_k is not None:
             self.context_params.type_k = type_k
@@ -1774,6 +1777,7 @@ class Llama:
            logits_all=self.context_params.logits_all,
            embedding=self.context_params.embeddings,
            offload_kqv=self.context_params.offload_kqv,
+           flash_attn=self.context_params.flash_attn,
            # Sampling Params
            last_n_tokens_size=self.last_n_tokens_size,
            # LoRA Params

diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py
index 0c858f9..ed05a88 100644
--- a/llama_cpp/server/settings.py
+++ b/llama_cpp/server/settings.py
@@ -96,6 +96,9 @@ class ModelSettings(BaseSettings):
     offload_kqv: bool = Field(
         default=True, description="Whether to offload kqv to the GPU."
     )
+    flash_attn: bool = Field(
+        default=False, description="Whether to use flash attention."
+    )
     # Sampling Params
     last_n_tokens_size: int = Field(
         default=64,
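
Usage sketch (not part of the patch): with this change applied, flash attention can be enabled
directly on the Llama constructor. The model path below is a placeholder, not a file shipped
with the project.

    from llama_cpp import Llama

    # Placeholder path; substitute any local GGUF model file.
    llm = Llama(
        model_path="./models/model.gguf",
        n_gpu_layers=-1,   # offload all layers to the GPU so flash attention can take effect
        flash_attn=True,   # new flag added by this patch; defaults to False
    )

    out = llm("Q: What is flash attention? A:", max_tokens=64)
    print(out["choices"][0]["text"])

Because the new field is declared on ModelSettings, the server should pick it up through its
usual settings mechanism as well (e.g. a FLASH_ATTN environment variable or the generated
command-line flag), though the exact flag spelling depends on how the server maps ModelSettings
fields to CLI options.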