Use flash attention flag for now (#4580)
* put flash attention behind flag for now * add test * remove print * up timeout for sheduler tests
This commit is contained in:
parent
73630a7e85
commit
38255d2af1
4 changed files with 19 additions and 6 deletions
|
@ -200,20 +200,20 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
params = append(params, "--numa")
|
params = append(params, "--numa")
|
||||||
}
|
}
|
||||||
|
|
||||||
flashAttnSupported := true
|
flashAttnEnabled := envconfig.FlashAttention
|
||||||
|
|
||||||
// partial offloading does not support flash attention
|
// partial offloading does not support flash attention
|
||||||
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
|
||||||
flashAttnSupported = false
|
flashAttnEnabled = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// only cuda (compute capability 7+) and metal support flash attention
|
// only cuda (compute capability 7+) and metal support flash attention
|
||||||
for _, g := range gpus {
|
for _, g := range gpus {
|
||||||
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
|
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
|
||||||
flashAttnSupported = false
|
flashAttnEnabled = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if flashAttnSupported {
|
if flashAttnEnabled {
|
||||||
params = append(params, "--flash-attn")
|
params = append(params, "--flash-attn")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,8 @@ var (
|
||||||
RunnersDir string
|
RunnersDir string
|
||||||
// Set via OLLAMA_TMPDIR in the environment
|
// Set via OLLAMA_TMPDIR in the environment
|
||||||
TmpDir string
|
TmpDir string
|
||||||
|
// Experimental flash attention
|
||||||
|
FlashAttention bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func AsMap() map[string]string {
|
func AsMap() map[string]string {
|
||||||
|
@ -45,6 +47,7 @@ func AsMap() map[string]string {
|
||||||
"OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel),
|
"OLLAMA_NUM_PARALLEL": fmt.Sprintf("%v", NumParallel),
|
||||||
"OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir),
|
"OLLAMA_RUNNERS_DIR": fmt.Sprintf("%v", RunnersDir),
|
||||||
"OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir),
|
"OLLAMA_TMPDIR": fmt.Sprintf("%v", TmpDir),
|
||||||
|
"OLLAMA_FLASH_ATTENTION": fmt.Sprintf("%v", FlashAttention),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +81,13 @@ func LoadConfig() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if fa := clean("OLLAMA_FLASH_ATTENTION"); fa != "" {
|
||||||
|
d, err := strconv.ParseBool(fa)
|
||||||
|
if err == nil {
|
||||||
|
FlashAttention = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RunnersDir = clean("OLLAMA_RUNNERS_DIR")
|
RunnersDir = clean("OLLAMA_RUNNERS_DIR")
|
||||||
if runtime.GOOS == "windows" && RunnersDir == "" {
|
if runtime.GOOS == "windows" && RunnersDir == "" {
|
||||||
// On Windows we do not carry the payloads inside the main executable
|
// On Windows we do not carry the payloads inside the main executable
|
||||||
|
|
|
@ -17,4 +17,7 @@ func TestConfig(t *testing.T) {
|
||||||
t.Setenv("OLLAMA_DEBUG", "1")
|
t.Setenv("OLLAMA_DEBUG", "1")
|
||||||
LoadConfig()
|
LoadConfig()
|
||||||
require.True(t, Debug)
|
require.True(t, Debug)
|
||||||
|
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
|
||||||
|
LoadConfig()
|
||||||
|
require.True(t, FlashAttention)
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,7 +151,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequests(t *testing.T) {
|
func TestRequests(t *testing.T) {
|
||||||
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
ctx, done := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
// Same model, same request
|
// Same model, same request
|
||||||
|
|
Loading…
Reference in a new issue