Use flash attention flag for now (#4580)

* put flash attention behind flag for now
* add test
* remove print
* up timeout for scheduler tests
commit 38255d2af1
parent 73630a7e85
4 changed files with 19 additions and 6 deletions
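In short, flash attention now defaults to off: the --flash-attn runner flag is only added when OLLAMA_FLASH_ATTENTION is set in the server's environment to a value strconv.ParseBool accepts (for example 1 or true), and even then only when the hardware checks in the NewLlamaServer hunk below still allow it.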
@@ -200,20 +200,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		params = append(params, "--numa")
 	}

-	flashAttnSupported := true
+	flashAttnEnabled := envconfig.FlashAttention

 	// partial offloading does not support flash attention
 	if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
-		flashAttnSupported = false
+		flashAttnEnabled = false
 	}

 	// only cuda (compute capability 7+) and metal support flash attention
 	for _, g := range gpus {
 		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
-			flashAttnSupported = false
+			flashAttnEnabled = false
 		}
 	}

-	if flashAttnSupported {
+	if flashAttnEnabled {
 		params = append(params, "--flash-attn")
 	}
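Taken on its own, the hunk above boils down to: start from the user's flag and drop it whenever flash attention cannot work for this load. A minimal self-contained sketch of that decision, where gpuInfo and flashAttnArgs are hypothetical stand-ins for the real gpu.GpuInfoList / NewLlamaServer code and are not part of this change:

package main

import "fmt"

// gpuInfo is a hypothetical stand-in for one entry of gpu.GpuInfoList;
// only the two fields the check reads are kept.
type gpuInfo struct {
	Library     string
	DriverMajor int
}

// flashAttnArgs mirrors the gating above: start from the user's flag
// (envconfig.FlashAttention) and clear it whenever flash attention
// cannot work for this load.
func flashAttnArgs(enabled bool, gpus []gpuInfo, numGPU int, blockCount uint64) []string {
	// partial offloading does not support flash attention
	if uint64(numGPU) < blockCount+1 {
		enabled = false
	}

	// only cuda (compute capability 7+) and metal support flash attention
	for _, g := range gpus {
		if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
			enabled = false
		}
	}

	if enabled {
		return []string{"--flash-attn"}
	}
	return nil
}

func main() {
	gpus := []gpuInfo{{Library: "cuda", DriverMajor: 8}}
	fmt.Println(flashAttnArgs(true, gpus, 33, 32))  // [--flash-attn]
	fmt.Println(flashAttnArgs(false, gpus, 33, 32)) // [] (flag unset, so never added)
}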
@@ -31,6 +31,8 @@ var (
 	RunnersDir string
 	// Set via OLLAMA_TMPDIR in the environment
 	TmpDir string
+	// Experimental flash attention
+	FlashAttention bool
 )

 func AsMap() map[string]string {
@@ -45,6 +47,7 @@ func AsMap() map[string]string {
 		"OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel),
 		"OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir),
 		"OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir),
+		"OLLAMA_FLASH_ATTENTION": fmt.Sprintf("%v", FlashAttention),
 	}
 }

@@ -78,6 +81,13 @@ func LoadConfig() {
 		}
 	}

+	if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" {
+		d, err := strconv.ParseBool(fa)
+		if err == nil {
+			FlashAttention = d
+		}
+	}
+
 	RunnersDir = clean("OLLAMA_RUNNERS_DIR")
 	if runtime.GOOS == "windows" && RunnersDir == "" {
 		// On Windows we do not carry the payloads inside the main executable
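strconv.ParseBool is what decides which OLLAMA_FLASH_ATTENTION values count as enabled; a quick standalone check of the accepted spellings (standard library behavior, not part of this change):

package main

import (
	"fmt"
	"strconv"
)

func main() {
	// strconv.ParseBool accepts 1, t, T, TRUE, true, True and the matching
	// false spellings; anything else returns an error, in which case the
	// LoadConfig hunk above leaves FlashAttention at its zero value (false).
	for _, v := range []string{"1", "true", "TRUE", "0", "false", "yes"} {
		b, err := strconv.ParseBool(v)
		fmt.Printf("%-5s -> %v (err: %v)\n", v, b, err)
	}
}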
@@ -17,4 +17,7 @@ func TestConfig(t *testing.T) {
 	t.Setenv("OLLAMA_DEBUG", "1")
 	LoadConfig()
 	require.True(t, Debug)
+	t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
+	LoadConfig()
+	require.True(t, FlashAttention)
 }
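Note that t.Setenv scopes the override to this test: the previous value of OLLAMA_FLASH_ATTENTION is restored automatically when the test finishes, so later tests see the default again.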
@@ -151,7 +151,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
 }

 func TestRequests(t *testing.T) {
-	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
+	ctx, done := context.WithTimeout(context.Background(), time.Second)
 	defer done()

 	// Same model, same request
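For reference, the pattern being tuned here is the usual context.WithTimeout with a deferred cancel; a minimal sketch in which the select body is a stand-in for the work under test, not scheduler code:

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	// Same shape as TestRequests above: bound the whole run with a deadline
	// and release the timer via the returned cancel func. time.Second
	// matches the bumped test timeout.
	ctx, done := context.WithTimeout(context.Background(), time.Second)
	defer done()

	select {
	case <-time.After(100 * time.Millisecond): // stand-in for the work under test
		fmt.Println("finished before the deadline")
	case <-ctx.Done():
		fmt.Println("timed out:", ctx.Err())
	}
}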