Only set default keep_alive on initial model load

This change fixes the handling of keep_alive so that if client
request omits the setting, we only set this on initial load.  Once
the model is loaded, if new requests leave this unset, we'll keep
whatever keep_alive was there.
This commit is contained in:
Daniel Hiltgen 2024-07-02 15:12:43 -07:00
parent ccd7785859
commit 955f2a4e03
5 changed files with 70 additions and 71 deletions

View file

@ -4,12 +4,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"time"
) )
type OllamaHost struct { type OllamaHost struct {
@ -34,7 +36,7 @@ var (
// Set via OLLAMA_HOST in the environment // Set via OLLAMA_HOST in the environment
Host *OllamaHost Host *OllamaHost
// Set via OLLAMA_KEEP_ALIVE in the environment // Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive string KeepAlive time.Duration
// Set via OLLAMA_LLM_LIBRARY in the environment // Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment // Set via OLLAMA_MAX_LOADED_MODELS in the environment
@ -132,6 +134,7 @@ func init() {
NumParallel = 0 // Autoselect NumParallel = 0 // Autoselect
MaxRunners = 0 // Autoselect MaxRunners = 0 // Autoselect
MaxQueuedRequests = 512 MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute
LoadConfig() LoadConfig()
} }
@ -266,7 +269,10 @@ func LoadConfig() {
} }
} }
KeepAlive = clean("OLLAMA_KEEP_ALIVE") ka := clean("OLLAMA_KEEP_ALIVE")
if ka != "" {
loadKeepAlive(ka)
}
var err error var err error
ModelsDir, err = getModelsDir() ModelsDir, err = getModelsDir()
@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
Port: port, Port: port,
}, nil }, nil
} }
func loadKeepAlive(ka string) {
v, err := strconv.Atoi(ka)
if err != nil {
d, err := time.ParseDuration(ka)
if err == nil {
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
} else {
d := time.Duration(v) * time.Second
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
}

View file

@ -2,8 +2,10 @@ package envconfig
import ( import (
"fmt" "fmt"
"math"
"net" "net"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_FLASH_ATTENTION", "1") t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig() LoadConfig()
require.True(t, FlashAttention) require.True(t, FlashAttention)
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
} }
func TestClientFromEnvironment(t *testing.T) { func TestClientFromEnvironment(t *testing.T) {

View file

@ -9,7 +9,6 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@ -17,7 +16,6 @@ import (
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@ -56,8 +54,6 @@ func init() {
gin.SetMode(mode) gin.SetMode(mode)
} }
var defaultSessionDuration = 5 * time.Minute
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil { if err := opts.FromMap(model.Options); err != nil {
@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
var sessionDuration time.Duration rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef var runner *runnerRef
select { select {
case runner = <-rCh: case runner = <-rCh:
@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func getDefaultSessionDuration() time.Duration {
if envconfig.KeepAlive != "" {
v, err := strconv.Atoi(envconfig.KeepAlive)
if err != nil {
d, err := time.ParseDuration(envconfig.KeepAlive)
if err != nil {
return defaultSessionDuration
}
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
d := time.Duration(v) * time.Second
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
return defaultSessionDuration
}
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
var sessionDuration time.Duration rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef var runner *runnerRef
select { select {
case runner = <-rCh: case runner = <-rCh:
@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
var sessionDuration time.Duration rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef var runner *runnerRef
select { select {
case runner = <-rCh: case runner = <-rCh:

View file

@ -24,7 +24,7 @@ type LlmRequest struct {
model *Model model *Model
opts api.Options opts api.Options
origNumCtx int // Track the initial ctx request origNumCtx int // Track the initial ctx request
sessionDuration time.Duration sessionDuration *api.Duration
successCh chan *runnerRef successCh chan *runnerRef
errCh chan error errCh chan error
schedAttempts uint schedAttempts uint
@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
} }
// context must be canceled to decrement ref count and release the runner // context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) { func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 { if opts.NumCtx < 4 {
opts.NumCtx = 4 opts.NumCtx = 4
} }
@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
runner.expireTimer.Stop() runner.expireTimer.Stop()
runner.expireTimer = nil runner.expireTimer = nil
} }
runner.sessionDuration = pending.sessionDuration if pending.sessionDuration != nil {
runner.sessionDuration = pending.sessionDuration.Duration
}
pending.successCh <- runner pending.successCh <- runner
go func() { go func() {
<-pending.ctx.Done() <-pending.ctx.Done()
@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
if numParallel < 1 { if numParallel < 1 {
numParallel = 1 numParallel = 1
} }
sessionDuration := envconfig.KeepAlive
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil { if err != nil {
// some older models are not compatible with newer versions of llama.cpp // some older models are not compatible with newer versions of llama.cpp
@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
modelPath: req.model.ModelPath, modelPath: req.model.ModelPath,
llama: llama, llama: llama,
Options: &req.opts, Options: &req.opts,
sessionDuration: req.sessionDuration, sessionDuration: sessionDuration,
gpus: gpus, gpus: gpus,
estimatedVRAM: llama.EstimatedVRAM(), estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(), estimatedTotal: llama.EstimatedTotal(),

View file

@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
sessionDuration: 2, sessionDuration: &api.Duration{Duration: 2 * time.Second},
} }
// Fail to load model first // Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
ctx: scenario.ctx, ctx: scenario.ctx,
model: model, model: model,
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond, sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
} }
@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
// Same model, same request // Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10) scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 5 * time.Millisecond scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11) scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0 scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model // simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20) scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = 5 * time.Millisecond scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models // Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
defer done() defer done()
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0 scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0 scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0 scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
envconfig.MaxQueuedRequests = 1 envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList { s.getGpuFn = func() gpu.GpuInfoList {
@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
time.Sleep(scenario1a.req.sessionDuration) time.Sleep(scenario1a.req.sessionDuration.Duration)
scenario1a.ctxDone() scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1) require.LessOrEqual(t, len(s.finishedReqCh), 1)
@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
ctx: ctx, ctx: ctx,
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
sessionDuration: 2, sessionDuration: &api.Duration{Duration: 2},
} }
finished := make(chan *LlmRequest) finished := make(chan *LlmRequest)
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
dctx, done2 := context.WithCancel(ctx) dctx, done2 := context.WithCancel(ctx)
done2() done2()
scenario1a := newScenario(t, dctx, "ollama-model-1", 10) scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0 scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
s := InitScheduler(ctx) s := InitScheduler(ctx)
slog.Info("scenario1a") slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req s.pendingReqCh <- scenario1a.req