feat: Add option to enable flash_attn to Llama params and ModelSettings

parent 8c2b24d5aa
commit 22d77eefd2

2 changed files with 7 additions and 0 deletions
@@ -92,6 +92,7 @@ class Llama:
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
+        flash_attn: bool = False,
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
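For reference, a minimal usage sketch of the new constructor flag (the model path is a placeholder and n_gpu_layers is not part of this commit; flash attention is only expected to matter when attention runs on a supported GPU backend):

from llama_cpp import Llama

# Sketch only: enable the new flag at construction time.
# "./models/model.gguf" is a placeholder path; n_gpu_layers=-1 offloads all layers.
llm = Llama(
    model_path="./models/model.gguf",
    n_gpu_layers=-1,
    flash_attn=True,  # parameter introduced by this commit, defaults to False
)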
@@ -168,6 +169,7 @@ class Llama:
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
+            flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
@@ -310,6 +312,7 @@ class Llama:
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
+        self.context_params.flash_attn = flash_attn
         # KV cache quantization
         if type_k is not None:
             self.context_params.type_k = type_k
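The flag is forwarded verbatim to the llama.cpp context parameters. A rough illustration with the low-level bindings (assuming the ctypes struct returned by llama_context_default_params exposes a flash_attn field in this version, which is what the wrapper sets above):

import llama_cpp

# Sketch of what the high-level wrapper does internally: take the default
# context params and flip the flash_attn field before the context is created.
ctx_params = llama_cpp.llama_context_default_params()
ctx_params.flash_attn = True  # mirrors self.context_params.flash_attn = flash_attn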
@@ -1774,6 +1777,7 @@ class Llama:
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
+            flash_offload=self.context_params.flash_offload,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
@@ -96,6 +96,9 @@ class ModelSettings(BaseSettings):
     offload_kqv: bool = Field(
         default=True, description="Whether to offload kqv to the GPU."
     )
+    flash_attn: bool = Field(
+        default=False, description="Whether to use flash attention."
+    )
     # Sampling Params
     last_n_tokens_size: int = Field(
         default=64,
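Since ModelSettings is a pydantic BaseSettings model, the new field can be set programmatically; a small sketch (the model path is a placeholder, all other fields keep their defaults):

from llama_cpp.server.settings import ModelSettings

# Sketch only: server-side model settings with flash attention enabled.
settings = ModelSettings(model="./models/model.gguf", flash_attn=True)
print(settings.flash_attn)  # True

The field should also be reachable the same way as the other boolean settings, e.g. as a --flash_attn option of the bundled server CLI, though that wiring is outside this diff.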