Merge pull request #5447 from dhiltgen/fix_keepalive

Only set default keep_alive on initial model load
Daniel Hiltgen 2024-07-03 15:34:38 -07:00 committed by GitHub
commit 8072e205ff
5 changed files with 70 additions and 71 deletions
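
The new default is resolved once from OLLAMA_KEEP_ALIVE at startup: a bare integer is read as seconds, any other value is parsed as a Go duration string, negative values mean "never unload", and an unparseable value leaves the 5 minute default in place. Below is a minimal standalone sketch of those parsing rules; the parseKeepAlive helper and the sample values are illustrative only, not code from this repository.

package main

import (
	"fmt"
	"math"
	"strconv"
	"time"
)

// parseKeepAlive mirrors the rules this change introduces: bare integers are
// seconds, other strings go through time.ParseDuration, negatives mean
// "keep the model loaded indefinitely", and failures keep the fallback.
// (Illustrative helper, not part of the ollama codebase.)
func parseKeepAlive(val string, fallback time.Duration) time.Duration {
	if val == "" {
		return fallback
	}
	if secs, err := strconv.Atoi(val); err == nil {
		if secs < 0 {
			return time.Duration(math.MaxInt64)
		}
		return time.Duration(secs) * time.Second
	}
	if d, err := time.ParseDuration(val); err == nil {
		if d < 0 {
			return time.Duration(math.MaxInt64)
		}
		return d
	}
	return fallback
}

func main() {
	def := 5 * time.Minute // default applied when OLLAMA_KEEP_ALIVE is unset
	for _, v := range []string{"", "3", "1h", "-1s", "-1"} {
		fmt.Printf("%-5q -> %v\n", v, parseKeepAlive(v, def))
	}
}

Running this prints 5m0s for the empty value, 3s for "3", 1h0m0s for "1h", and the max int64 duration for both negative forms, matching the cases covered by the updated envconfig test below.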


@@ -4,12 +4,14 @@ import (
"errors"
"fmt"
"log/slog"
"math"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
)
type OllamaHost struct {
@@ -34,7 +36,7 @@ var (
// Set via OLLAMA_HOST in the environment
Host *OllamaHost
// Set via OLLAMA_KEEP_ALIVE in the environment
KeepAlive string
KeepAlive time.Duration
// Set via OLLAMA_LLM_LIBRARY in the environment
LLMLibrary string
// Set via OLLAMA_MAX_LOADED_MODELS in the environment
@@ -132,6 +134,7 @@ func init() {
NumParallel = 0 // Autoselect
MaxRunners = 0 // Autoselect
MaxQueuedRequests = 512
KeepAlive = 5 * time.Minute
LoadConfig()
}
@@ -266,7 +269,10 @@ func LoadConfig() {
}
}
KeepAlive = clean("OLLAMA_KEEP_ALIVE")
ka := clean("OLLAMA_KEEP_ALIVE")
if ka != "" {
loadKeepAlive(ka)
}
var err error
ModelsDir, err = getModelsDir()
@@ -344,3 +350,24 @@ func getOllamaHost() (*OllamaHost, error) {
Port: port,
}, nil
}
func loadKeepAlive(ka string) {
v, err := strconv.Atoi(ka)
if err != nil {
d, err := time.ParseDuration(ka)
if err == nil {
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
} else {
d := time.Duration(v) * time.Second
if d < 0 {
KeepAlive = time.Duration(math.MaxInt64)
} else {
KeepAlive = d
}
}
}


@@ -2,8 +2,10 @@ package envconfig
import (
"fmt"
"math"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -23,6 +25,21 @@ func TestConfig(t *testing.T) {
t.Setenv("OLLAMA_FLASH_ATTENTION", "1")
LoadConfig()
require.True(t, FlashAttention)
t.Setenv("OLLAMA_KEEP_ALIVE", "")
LoadConfig()
require.Equal(t, 5*time.Minute, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "3")
LoadConfig()
require.Equal(t, 3*time.Second, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "1h")
LoadConfig()
require.Equal(t, 1*time.Hour, KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1s")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
t.Setenv("OLLAMA_KEEP_ALIVE", "-1")
LoadConfig()
require.Equal(t, time.Duration(math.MaxInt64), KeepAlive)
}
func TestClientFromEnvironment(t *testing.T) {


@@ -9,7 +9,6 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"net"
"net/http"
"net/netip"
@@ -17,7 +16,6 @@ import (
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
@@ -56,8 +54,6 @@ func init() {
gin.SetMode(mode)
}
var defaultSessionDuration = 5 * time.Minute
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@@ -133,14 +129,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
@@ -320,32 +309,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch)
}
func getDefaultSessionDuration() time.Duration {
if envconfig.KeepAlive != "" {
v, err := strconv.Atoi(envconfig.KeepAlive)
if err != nil {
d, err := time.ParseDuration(envconfig.KeepAlive)
if err != nil {
return defaultSessionDuration
}
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
d := time.Duration(v) * time.Second
if d < 0 {
return time.Duration(math.MaxInt64)
}
return d
}
return defaultSessionDuration
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
@@ -380,14 +343,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:
@@ -1318,14 +1274,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive)
var runner *runnerRef
select {
case runner = <-rCh:


@@ -24,7 +24,7 @@ type LlmRequest struct {
model *Model
opts api.Options
origNumCtx int // Track the initial ctx request
sessionDuration time.Duration
sessionDuration *api.Duration
successCh chan *runnerRef
errCh chan error
schedAttempts uint
@@ -75,7 +75,7 @@ func InitScheduler(ctx context.Context) *Scheduler {
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
@@ -389,7 +389,9 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
runner.expireTimer.Stop()
runner.expireTimer = nil
}
runner.sessionDuration = pending.sessionDuration
if pending.sessionDuration != nil {
runner.sessionDuration = pending.sessionDuration.Duration
}
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
@@ -402,6 +404,10 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
if numParallel < 1 {
numParallel = 1
}
sessionDuration := envconfig.KeepAlive
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
@@ -419,7 +425,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
modelPath: req.model.ModelPath,
llama: llama,
Options: &req.opts,
sessionDuration: req.sessionDuration,
sessionDuration: sessionDuration,
gpus: gpus,
estimatedVRAM: llama.EstimatedVRAM(),
estimatedTotal: llama.EstimatedTotal(),


@@ -44,7 +44,7 @@ func TestLoad(t *testing.T) {
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
sessionDuration: &api.Duration{Duration: 2 * time.Second},
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
@@ -142,7 +142,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
ctx: scenario.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond,
sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
@@ -156,18 +156,18 @@ func TestRequests(t *testing.T) {
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 5 * time.Millisecond
scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = 0
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = 5 * time.Millisecond
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
@@ -318,11 +318,11 @@ func TestGetRunner(t *testing.T) {
defer done()
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
@@ -402,7 +402,7 @@ func TestPrematureExpired(t *testing.T) {
case <-ctx.Done():
t.Fatal("timeout")
}
time.Sleep(scenario1a.req.sessionDuration)
time.Sleep(scenario1a.req.sessionDuration.Duration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
@@ -423,7 +423,7 @@ func TestUseLoadedRunner(t *testing.T) {
ctx: ctx,
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
sessionDuration: &api.Duration{Duration: 2},
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
@@ -614,7 +614,7 @@ func TestAlreadyCanceled(t *testing.T) {
dctx, done2 := context.WithCancel(ctx)
done2()
scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
s := InitScheduler(ctx)
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
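
With this change the request-level keep_alive reaches the scheduler as a *api.Duration: a nil value no longer overwrites the session duration of an already-loaded runner, and the envconfig default is applied only when the model is first loaded. A hedged sketch of the resulting client-side semantics using the Go API client follows; the model name and prompt are placeholders.

package main

import (
	"context"
	"log"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}
	ctx := context.Background()

	// KeepAlive left nil: the scheduler falls back to envconfig.KeepAlive
	// (OLLAMA_KEEP_ALIVE, or the 5 minute default) only on initial model load,
	// and leaves an already-loaded runner's expiry untouched.
	reqDefault := &api.GenerateRequest{
		Model:  "llama3", // placeholder model name
		Prompt: "why is the sky blue?",
	}

	// Explicit KeepAlive: pins this request's runner to the given duration.
	reqExplicit := &api.GenerateRequest{
		Model:     "llama3",
		Prompt:    "why is the sky blue?",
		KeepAlive: &api.Duration{Duration: 10 * time.Minute},
	}

	respFn := func(r api.GenerateResponse) error { return nil }
	if err := client.Generate(ctx, reqDefault, respFn); err != nil {
		log.Fatal(err)
	}
	if err := client.Generate(ctx, reqExplicit, respFn); err != nil {
		log.Fatal(err)
	}
}

The first request takes the server-side default only if it triggers the initial load; the second overrides the runner's keep-alive to ten minutes regardless.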