Use flash attention flag for now (#4580)

* put flash attention behind flag for now

* add test

* remove print

* up timeout for sheduler tests
This commit is contained in:
Jeffrey Morgan 2024-05-22 21:52:09 -07:00 committed by GitHub
parent 73630a7e85
commit 38255d2af1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 19 additions and 6 deletions

View file

@ -200,20 +200,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa") params = append(params, "--numa")
} }
flashAttnSupported := true flashAttnEnabled := envconfig.FlashAttention
// partial offloading does not support flash attention // partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 { if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
flashAttnSupported = false flashAttnEnabled = false
} }
// only cuda (compute capability 7+) and metal support flash attention // only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus { for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) { if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnSupported = false flashAttnEnabled = false
} }
} }
if flashAttnSupported { if flashAttnEnabled {
params = append(params, "--flash-attn") params = append(params, "--flash-attn")
} }

View file

@ -31,6 +31,8 @@ var (
RunnersDir string RunnersDir string
// Set via OLLAMA_TMPDIR in the environment // Set via OLLAMA_TMPDIR in the environment
TmpDir string TmpDir string
// Experimental flash attention
FlashAttention bool
) )
func AsMap() map[string]string { func AsMap() map[string]string {
@ -45,6 +47,7 @@ func AsMap() map[string]string {
"OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel), "OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel),
"OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir), "OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir),
"OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir), "OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir),
"OLLAMA_FLASH_ATTENTION": fmt.Sprintf("%v", FlashAttention),
} }
} }
@ -78,6 +81,13 @@ func LoadConfig() {
} }
} }
if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" {
d, err := strconv.ParseBool(fa)
if err == nil {
FlashAttention = d
}
}
RunnersDir = clean("OLLAMA_RUNNERS_DIR") RunnersDir = clean("OLLAMA_RUNNERS_DIR")
if runtime.GOOS == "windows" && RunnersDir == "" { if runtime.GOOS == "windows" && RunnersDir == "" {
// On Windows we do not carry the payloads inside the main executable // On Windows we do not carry the payloads inside the main executable

View file

@ -17,4 +17,7 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1") t.Setenv("OLLAMA_DEBUG", "1")
LoadConfig() LoadConfig()
require.True(t, Debug) require.True(t, Debug)
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
} }

View file

@ -151,7 +151,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
} }
func TestRequests(t *testing.T) { func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) ctx, done := context.WithTimeout(context.Background(), time.Second)
defer done() defer done()
// Same model, same request // Same model, same request