feat: add support for flash_attn (#4120)

* feat: enable flash attention if supported

* feat: enable flash attention if supported

* feat: enable flash attention if supported

* feat: add flash_attn support
This commit is contained in:
Sam 2024-05-21 06:36:03 +10:00 committed by GitHub
parent ccdf0b2a44
commit e15307fdf4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 28 additions and 3 deletions

View file

@@ -2104,6 +2104,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
    printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
    printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
    printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
    printf(" -spf FNAME, --system-prompt-file FNAME\n");
    printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
    printf(" -ctk TYPE, --cache-type-k TYPE\n");
@@ -2501,7 +2502,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
        {
            params.use_mmap = false;
        }
        else if (arg == "--numa")
        {
            if (++i >= argc) {
                invalid_param = true;
                break;
@@ -2521,6 +2523,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
        {
            params.cont_batching = true;
        }
        else if (arg == "-fa" || arg == "--flash-attn")
        {
            params.flash_attn = true;
        }
        else if (arg == "-np" || arg == "--parallel")
        {
            if (++i >= argc)
@@ -2529,7 +2535,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                break;
            }
            params.n_parallel = std::stoi(argv[i]);
        }
        else if (arg == "-n" || arg == "--n-predict")
        {
            if (++i >= argc)
            {
@@ -2537,7 +2544,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                break;
            }
            params.n_predict = std::stoi(argv[i]);
        }
        else if (arg == "-spf" || arg == "--system-prompt-file")
        {
            if (++i >= argc)
            {

View file

@@ -200,6 +200,23 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
		params = append(params, "--numa")
	}

	flashAttnSupported := true

	// partial offloading does not support flash attention
	if uint64(opts.NumGPU) < ggml.KV().BlockCount() + 1 {
		flashAttnSupported = false
	}

	// only cuda (compute capability 7+) and metal support flash attention
	for _, g := range gpus {
		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
			flashAttnSupported = false
		}
	}

	if flashAttnSupported {
		params = append(params, "--flash-attn")
	}

	numParallel := envconfig.NumParallel
	// TODO (jmorganca): multimodal models don't support parallel yet