Use flash attention flag for now (#4580)

* put flash attention behind flag for now

* add test

* remove print

* up timeout for sheduler tests
This commit is contained in:
Jeffrey Morgan 2024-05-22 21:52:09 -07:00 committed by GitHub
parent 73630a7e85
commit 38255d2af1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 19 additions and 6 deletions

View file

@ -200,20 +200,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--numa")
}
flashAttnSupported := true
flashAttnEnabled := envconfig.FlashAttention
// partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
flashAttnSupported = false
flashAttnEnabled = false
}
// only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnSupported = false
flashAttnEnabled = false
}
}
if flashAttnSupported {
if flashAttnEnabled {
params = append(params, "--flash-attn")
}

View file

@ -31,6 +31,8 @@ var (
RunnersDir string
// Set via OLLAMA_TMPDIR in the environment
TmpDir string
// Experimental flash attention
FlashAttention bool
)
func AsMap() map[string]string {
@ -45,6 +47,7 @@ func AsMap() map[string]string {
"OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel),
"OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir),
"OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir),
"OLLAMA_FLASH_ATTENTION": fmt.Sprintf("%v", FlashAttention),
}
}
@ -78,6 +81,13 @@ func LoadConfig() {
}
}
if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" {
d, err := strconv.ParseBool(fa)
if err == nil {
FlashAttention = d
}
}
RunnersDir = clean("OLLAMA_RUNNERS_DIR")
if runtime.GOOS == "windows" && RunnersDir == "" {
// On Windows we do not carry the payloads inside the main executable

View file

@ -17,4 +17,7 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", "1")
LoadConfig()
require.True(t, Debug)
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
}

View file

@ -151,7 +151,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
}
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
ctx, done := context.WithTimeout(context.Background(), time.Second)
defer done()
// Same model, same request