Only set default keep_alive on initial model load
This change fixes the handling of keep_alive so that if client request omits the setting, we only set this on initial load. Once the model is loaded, if new requests leave this unset, we'll keep whatever keep_alive was there.
This commit is contained in:
parent
ccd7785859
commit
955f2a4e03
5 changed files with 70 additions and 71 deletions
|
@ -4,12 +4,14 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OllamaHost struct {
|
type OllamaHost struct {
|
||||||
|
@ -34,7 +36,7 @@ var (
|
||||||
// Set via OLLAMA_HOST in the environment
|
// Set via OLLAMA_HOST in the environment
|
||||||
Host *OllamaHost
|
Host *OllamaHost
|
||||||
// Set via OLLAMA_KEEP_ALIVE in the environment
|
// Set via OLLAMA_KEEP_ALIVE in the environment
|
||||||
KeepAlive string
|
KeepAlive time.Duration
|
||||||
// Set via OLLAMA_LLM_LIBRARY in the environment
|
// Set via OLLAMA_LLM_LIBRARY in the environment
|
||||||
LLMLibrary string
|
LLMLibrary string
|
||||||
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
|
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
|
||||||
|
@ -132,6 +134,7 @@ func init() {
|
||||||
NumParallel = 0 // Autoselect
|
NumParallel = 0 // Autoselect
|
||||||
MaxRunners = 0 // Autoselect
|
MaxRunners = 0 // Autoselect
|
||||||
MaxQueuedRequests = 512
|
MaxQueuedRequests = 512
|
||||||
|
KeepAlive = 5 * time.Minute
|
||||||
|
|
||||||
LoadConfig()
|
LoadConfig()
|
||||||
}
|
}
|
||||||
|
@ -266,7 +269,10 @@ func LoadConfig() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
KeepAlive = clean("OLLAMA_KEEP_ALIVE")
|
ka := clean("OLLAMA_KEEP_ALIVE")
|
||||||
|
if ka != "" {
|
||||||
|
loadKeepAlive(ka)
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
ModelsDir, err = getModelsDir()
|
ModelsDir, err = getModelsDir()
|
||||||
|
@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
|
||||||
Port: port,
|
Port: port,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadKeepAlive(ka string) {
|
||||||
|
v, err := strconv.Atoi(ka)
|
||||||
|
if err != nil {
|
||||||
|
d, err := time.ParseDuration(ka)
|
||||||
|
if err == nil {
|
||||||
|
if d < 0 {
|
||||||
|
KeepAlive = time.Duration(math.MaxInt64)
|
||||||
|
} else {
|
||||||
|
KeepAlive = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
d := time.Duration(v) * time.Second
|
||||||
|
if d < 0 {
|
||||||
|
KeepAlive = time.Duration(math.MaxInt64)
|
||||||
|
} else {
|
||||||
|
KeepAlive = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,8 +2,10 @@ package envconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
|
||||||
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
|
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
|
||||||
LoadConfig()
|
LoadConfig()
|
||||||
require.True(t, FlashAttention)
|
require.True(t, FlashAttention)
|
||||||
|
t.Setenv("OLLAMA_KEEP_ALIVE", "")
|
||||||
|
LoadConfig()
|
||||||
|
require.Equal(t, 5*time.Minute, KeepAlive)
|
||||||
|
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
|
||||||
|
LoadConfig()
|
||||||
|
require.Equal(t, 3*time.Second, KeepAlive)
|
||||||
|
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
|
||||||
|
LoadConfig()
|
||||||
|
require.Equal(t, 1*time.Hour, KeepAlive)
|
||||||
|
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
|
||||||
|
LoadConfig()
|
||||||
|
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
|
||||||
|
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
|
||||||
|
LoadConfig()
|
||||||
|
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientFromEnvironment(t *testing.T) {
|
func TestClientFromEnvironment(t *testing.T) {
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -17,7 +16,6 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -56,8 +54,6 @@ func init() {
|
||||||
gin.SetMode(mode)
|
gin.SetMode(mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultSessionDuration = 5 * time.Minute
|
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
if err := opts.FromMap(model.Options); err != nil {
|
if err := opts.FromMap(model.Options); err != nil {
|
||||||
|
@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var sessionDuration time.Duration
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||||
if req.KeepAlive == nil {
|
|
||||||
sessionDuration = getDefaultSessionDuration()
|
|
||||||
} else {
|
|
||||||
sessionDuration = req.KeepAlive.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
||||||
var runner *runnerRef
|
var runner *runnerRef
|
||||||
select {
|
select {
|
||||||
case runner = <-rCh:
|
case runner = <-rCh:
|
||||||
|
@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDefaultSessionDuration() time.Duration {
|
|
||||||
if envconfig.KeepAlive != "" {
|
|
||||||
v, err := strconv.Atoi(envconfig.KeepAlive)
|
|
||||||
if err != nil {
|
|
||||||
d, err := time.ParseDuration(envconfig.KeepAlive)
|
|
||||||
if err != nil {
|
|
||||||
return defaultSessionDuration
|
|
||||||
}
|
|
||||||
|
|
||||||
if d < 0 {
|
|
||||||
return time.Duration(math.MaxInt64)
|
|
||||||
}
|
|
||||||
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
d := time.Duration(v) * time.Second
|
|
||||||
if d < 0 {
|
|
||||||
return time.Duration(math.MaxInt64)
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
return defaultSessionDuration
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
var req api.EmbeddingRequest
|
var req api.EmbeddingRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
|
@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var sessionDuration time.Duration
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||||
if req.KeepAlive == nil {
|
|
||||||
sessionDuration = getDefaultSessionDuration()
|
|
||||||
} else {
|
|
||||||
sessionDuration = req.KeepAlive.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
||||||
var runner *runnerRef
|
var runner *runnerRef
|
||||||
select {
|
select {
|
||||||
case runner = <-rCh:
|
case runner = <-rCh:
|
||||||
|
@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var sessionDuration time.Duration
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
|
||||||
if req.KeepAlive == nil {
|
|
||||||
sessionDuration = getDefaultSessionDuration()
|
|
||||||
} else {
|
|
||||||
sessionDuration = req.KeepAlive.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
|
||||||
var runner *runnerRef
|
var runner *runnerRef
|
||||||
select {
|
select {
|
||||||
case runner = <-rCh:
|
case runner = <-rCh:
|
||||||
|
|
|
@ -24,7 +24,7 @@ type LlmRequest struct {
|
||||||
model *Model
|
model *Model
|
||||||
opts api.Options
|
opts api.Options
|
||||||
origNumCtx int // Track the initial ctx request
|
origNumCtx int // Track the initial ctx request
|
||||||
sessionDuration time.Duration
|
sessionDuration *api.Duration
|
||||||
successCh chan *runnerRef
|
successCh chan *runnerRef
|
||||||
errCh chan error
|
errCh chan error
|
||||||
schedAttempts uint
|
schedAttempts uint
|
||||||
|
@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
||||||
}
|
}
|
||||||
|
|
||||||
// context must be canceled to decrement ref count and release the runner
|
// context must be canceled to decrement ref count and release the runner
|
||||||
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
|
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||||
if opts.NumCtx < 4 {
|
if opts.NumCtx < 4 {
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
|
@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
|
||||||
runner.expireTimer.Stop()
|
runner.expireTimer.Stop()
|
||||||
runner.expireTimer = nil
|
runner.expireTimer = nil
|
||||||
}
|
}
|
||||||
runner.sessionDuration = pending.sessionDuration
|
if pending.sessionDuration != nil {
|
||||||
|
runner.sessionDuration = pending.sessionDuration.Duration
|
||||||
|
}
|
||||||
pending.successCh <- runner
|
pending.successCh <- runner
|
||||||
go func() {
|
go func() {
|
||||||
<-pending.ctx.Done()
|
<-pending.ctx.Done()
|
||||||
|
@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
|
||||||
if numParallel < 1 {
|
if numParallel < 1 {
|
||||||
numParallel = 1
|
numParallel = 1
|
||||||
}
|
}
|
||||||
|
sessionDuration := envconfig.KeepAlive
|
||||||
|
if req.sessionDuration != nil {
|
||||||
|
sessionDuration = req.sessionDuration.Duration
|
||||||
|
}
|
||||||
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// some older models are not compatible with newer versions of llama.cpp
|
// some older models are not compatible with newer versions of llama.cpp
|
||||||
|
@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
|
||||||
modelPath: req.model.ModelPath,
|
modelPath: req.model.ModelPath,
|
||||||
llama: llama,
|
llama: llama,
|
||||||
Options: &req.opts,
|
Options: &req.opts,
|
||||||
sessionDuration: req.sessionDuration,
|
sessionDuration: sessionDuration,
|
||||||
gpus: gpus,
|
gpus: gpus,
|
||||||
estimatedVRAM: llama.EstimatedVRAM(),
|
estimatedVRAM: llama.EstimatedVRAM(),
|
||||||
estimatedTotal: llama.EstimatedTotal(),
|
estimatedTotal: llama.EstimatedTotal(),
|
||||||
|
|
|
@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
|
||||||
opts: api.DefaultOptions(),
|
opts: api.DefaultOptions(),
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
sessionDuration: 2,
|
sessionDuration: &api.Duration{Duration: 2 * time.Second},
|
||||||
}
|
}
|
||||||
// Fail to load model first
|
// Fail to load model first
|
||||||
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||||
|
@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
|
||||||
ctx: scenario.ctx,
|
ctx: scenario.ctx,
|
||||||
model: model,
|
model: model,
|
||||||
opts: api.DefaultOptions(),
|
opts: api.DefaultOptions(),
|
||||||
sessionDuration: 5 * time.Millisecond,
|
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
errCh: make(chan error, 1),
|
errCh: make(chan error, 1),
|
||||||
}
|
}
|
||||||
|
@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
|
||||||
|
|
||||||
// Same model, same request
|
// Same model, same request
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
|
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
|
||||||
scenario1a.req.sessionDuration = 5 * time.Millisecond
|
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||||
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
|
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
|
||||||
scenario1b.req.model = scenario1a.req.model
|
scenario1b.req.model = scenario1a.req.model
|
||||||
scenario1b.ggml = scenario1a.ggml
|
scenario1b.ggml = scenario1a.ggml
|
||||||
scenario1b.req.sessionDuration = 0
|
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
|
|
||||||
// simple reload of same model
|
// simple reload of same model
|
||||||
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
|
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
|
||||||
tmpModel := *scenario1a.req.model
|
tmpModel := *scenario1a.req.model
|
||||||
scenario2a.req.model = &tmpModel
|
scenario2a.req.model = &tmpModel
|
||||||
scenario2a.ggml = scenario1a.ggml
|
scenario2a.ggml = scenario1a.ggml
|
||||||
scenario2a.req.sessionDuration = 5 * time.Millisecond
|
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
|
||||||
|
|
||||||
// Multiple loaded models
|
// Multiple loaded models
|
||||||
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
||||||
|
@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
||||||
scenario1a.req.sessionDuration = 0
|
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
|
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
|
||||||
scenario1b.req.sessionDuration = 0
|
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
||||||
scenario1c.req.sessionDuration = 0
|
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
envconfig.MaxQueuedRequests = 1
|
envconfig.MaxQueuedRequests = 1
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
s.getGpuFn = func() gpu.GpuInfoList {
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
|
@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("timeout")
|
t.Fatal("timeout")
|
||||||
}
|
}
|
||||||
time.Sleep(scenario1a.req.sessionDuration)
|
time.Sleep(scenario1a.req.sessionDuration.Duration)
|
||||||
scenario1a.ctxDone()
|
scenario1a.ctxDone()
|
||||||
time.Sleep(20 * time.Millisecond)
|
time.Sleep(20 * time.Millisecond)
|
||||||
require.LessOrEqual(t, len(s.finishedReqCh), 1)
|
require.LessOrEqual(t, len(s.finishedReqCh), 1)
|
||||||
|
@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
opts: api.DefaultOptions(),
|
opts: api.DefaultOptions(),
|
||||||
successCh: make(chan *runnerRef, 1),
|
successCh: make(chan *runnerRef, 1),
|
||||||
sessionDuration: 2,
|
sessionDuration: &api.Duration{Duration: 2},
|
||||||
}
|
}
|
||||||
finished := make(chan *LlmRequest)
|
finished := make(chan *LlmRequest)
|
||||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||||
|
@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
|
||||||
dctx, done2 := context.WithCancel(ctx)
|
dctx, done2 := context.WithCancel(ctx)
|
||||||
done2()
|
done2()
|
||||||
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
|
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
|
||||||
scenario1a.req.sessionDuration = 0
|
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
|
||||||
s := InitScheduler(ctx)
|
s := InitScheduler(ctx)
|
||||||
slog.Info("scenario1a")
|
slog.Info("scenario1a")
|
||||||
s.pendingReqCh <- scenario1a.req
|
s.pendingReqCh <- scenario1a.req
|
||||||
|
|
Loading…
Reference in a new issue