feat: Add option to enable flash_attn to Llama params and ModelSettings

Andrei Betlen 2024-04-30 09:29:16 -04:00
parent 8c2b24d5aa
commit 22d77eefd2
2 changed files with 7 additions and 0 deletions


@@ -92,6 +92,7 @@ class Llama:
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
+        flash_attn: bool = False,
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
@@ -168,6 +169,7 @@ class Llama:
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
+            flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
@@ -310,6 +312,7 @@ class Llama:
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
+        self.context_params.flash_attn = flash_attn
         # KV cache quantization
         if type_k is not None:
             self.context_params.type_k = type_k
@@ -1774,6 +1777,7 @@ class Llama:
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
+            flash_attn=self.context_params.flash_attn,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
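
Taken together, the llama.py changes thread the new flash_attn flag from the Llama constructor into llama.cpp's context params and carry it along when the object's state is captured. A minimal usage sketch under the API this diff adds (the model path and prompt below are placeholders, not part of the commit):

    from llama_cpp import Llama

    # flash_attn is the constructor flag introduced by this commit; it is copied
    # into llama.cpp's context params and defaults to False.
    llm = Llama(
        model_path="./models/model.Q4_K_M.gguf",  # placeholder path
        n_gpu_layers=-1,  # flash attention is typically paired with GPU offload
        flash_attn=True,
    )
    out = llm("Q: Name the planets in the solar system. A:", max_tokens=32)
    print(out["choices"][0]["text"])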


@@ -96,6 +96,9 @@ class ModelSettings(BaseSettings):
     offload_kqv: bool = Field(
         default=True, description="Whether to offload kqv to the GPU."
     )
+    flash_attn: bool = Field(
+        default=False, description="Whether to use flash attention."
+    )
     # Sampling Params
     last_n_tokens_size: int = Field(
         default=64,
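
The ModelSettings counterpart exposes the same switch to the llama_cpp.server configuration. A small sketch of setting it programmatically, assuming the usual pydantic construction (the model path is again a placeholder):

    from llama_cpp.server.settings import ModelSettings

    # flash_attn mirrors the Llama constructor flag and keeps the same default (False).
    settings = ModelSettings(
        model="./models/model.Q4_K_M.gguf",  # placeholder path
        flash_attn=True,
    )
    print(settings.flash_attn)  # True

Because ModelSettings is a pydantic BaseSettings subclass, the field should also be reachable through the server's other settings sources (for example its generated command-line options), though this commit only adds the field itself.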