feat: add support for flash_attn (#4120)
* feat: enable flash attention if supported * feat: enable flash attention if supported * feat: enable flash attention if supported * feat: add flash_attn support
This commit is contained in:
parent
ccdf0b2a44
commit
e15307fdf4
2 changed files with 28 additions and 3 deletions
14
llm/ext_server/server.cpp
vendored
14
llm/ext_server/server.cpp
vendored
|
@ -2104,6 +2104,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
||||||
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
||||||
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
||||||
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
||||||
|
printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
|
||||||
printf(" -spf FNAME, --system-prompt-file FNAME\n");
|
printf(" -spf FNAME, --system-prompt-file FNAME\n");
|
||||||
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
|
printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
|
||||||
printf(" -ctk TYPE, --cache-type-k TYPE\n");
|
printf(" -ctk TYPE, --cache-type-k TYPE\n");
|
||||||
|
@ -2501,7 +2502,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
{
|
{
|
||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
}
|
}
|
||||||
else if (arg == "--numa") {
|
else if (arg == "--numa")
|
||||||
|
{
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
|
@ -2521,6 +2523,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
{
|
{
|
||||||
params.cont_batching = true;
|
params.cont_batching = true;
|
||||||
}
|
}
|
||||||
|
else if (arg == "-fa" || arg == "--flash-attn")
|
||||||
|
{
|
||||||
|
params.flash_attn = true;
|
||||||
|
}
|
||||||
else if (arg == "-np" || arg == "--parallel")
|
else if (arg == "-np" || arg == "--parallel")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
|
@ -2529,7 +2535,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.n_parallel = std::stoi(argv[i]);
|
params.n_parallel = std::stoi(argv[i]);
|
||||||
} else if (arg == "-n" || arg == "--n-predict")
|
}
|
||||||
|
else if (arg == "-n" || arg == "--n-predict")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
{
|
{
|
||||||
|
@ -2537,7 +2544,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.n_predict = std::stoi(argv[i]);
|
params.n_predict = std::stoi(argv[i]);
|
||||||
} else if (arg == "-spf" || arg == "--system-prompt-file")
|
}
|
||||||
|
else if (arg == "-spf" || arg == "--system-prompt-file")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
{
|
{
|
||||||
|
|
|
@ -200,6 +200,23 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
params = append(params, "--numa")
|
params = append(params, "--numa")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flashAttnSupported := true
|
||||||
|
|
||||||
|
// partial offloading does not support flash attention
|
||||||
|
if uint64(opts.NumGPU) < ggml.KV().BlockCount() + 1 {
|
||||||
|
flashAttnSupported = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// only cuda (compute capability 7+) and metal support flash attention
|
||||||
|
for _, g := range gpus {
|
||||||
|
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
|
||||||
|
flashAttnSupported = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if flashAttnSupported {
|
||||||
|
params = append(params, "--flash-attn")
|
||||||
|
}
|
||||||
|
|
||||||
numParallel := envconfig.NumParallel
|
numParallel := envconfig.NumParallel
|
||||||
|
|
||||||
// TODO (jmorganca): multimodal models don't support parallel yet
|
// TODO (jmorganca): multimodal models don't support parallel yet
|
||||||
|
|
Loading…
Reference in a new issue