Prevent loading models larger than total memory

Users may not realize the siny new model they're trying to load
fits on their disk, but can't load into system+GPU memory.  Today
we crash, but with this fix, we'll give them a better error message
before even trying to load it.
This commit is contained in:
Daniel Hiltgen 2024-07-03 14:47:42 -07:00
parent 3b5a4a77f3
commit 3c75113e37
2 changed files with 38 additions and 0 deletions

View file

@ -139,6 +139,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
} }
for { for {
cpus := s.getCpuFn()
var systemMem gpu.GpuInfo
if len(cpus) > 0 {
systemMem = cpus[0]
}
var runnerToExpire *runnerRef var runnerToExpire *runnerRef
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath] runner := s.loaded[pending.model.ModelPath]
@ -192,6 +197,27 @@ func (s *Scheduler) processPending(ctx context.Context) {
break break
} }
// Block attempting to load a model larger than system memory + GPU memory
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
// Evaluate if the model will fit in the available system memory, or if we should unload a model first // Evaluate if the model will fit in the available system memory, or if we should unload a model first
if len(gpus) == 1 && gpus[0].Library == "cpu" { if len(gpus) == 1 && gpus[0].Library == "cpu" {
// simplifying assumption of defaultParallel when in CPU mode // simplifying assumption of defaultParallel when in CPU mode

View file

@ -199,6 +199,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1a.req.errCh) require.Empty(t, scenario1a.req.errCh)
case err := <-scenario1a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@ -212,6 +214,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario1b.req.errCh) require.Empty(t, scenario1b.req.errCh)
case err := <-scenario1b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@ -230,6 +234,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario2a.srv) require.Equal(t, resp.llama, scenario2a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario2a.req.errCh) require.Empty(t, scenario2a.req.errCh)
case err := <-scenario2a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@ -246,6 +252,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3a.srv) require.Equal(t, resp.llama, scenario3a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3a.req.errCh) require.Empty(t, scenario3a.req.errCh)
case err := <-scenario3a.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@ -262,6 +270,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3b.srv) require.Equal(t, resp.llama, scenario3b.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3b.req.errCh) require.Empty(t, scenario3b.req.errCh)
case err := <-scenario3b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@ -278,6 +288,8 @@ func TestRequests(t *testing.T) {
require.Equal(t, resp.llama, scenario3c.srv) require.Equal(t, resp.llama, scenario3c.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, scenario3c.req.errCh) require.Empty(t, scenario3c.req.errCh)
case err := <-scenario3c.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }