Reload model if num_gpu changes (#3920)

* reload model if `num_gpu` changes

* dont reload on -1

* fix tests
This commit is contained in:
Jeffrey Morgan 2024-04-25 19:02:40 -04:00 committed by GitHub
parent 993cf8bf55
commit 00b0699c75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 6 deletions

View file

@ -421,16 +421,21 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
slog.Debug("evaluating already loaded", "model", req.model.ModelPath) slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
runner.refMu.Lock() runner.refMu.Lock()
defer runner.refMu.Unlock() defer runner.refMu.Unlock()
// Ignore the NumGPU settings for comparison
optsExisting := runner.Options.Runner
optsExisting.NumGPU = -1
optsNew := req.opts.Runner
optsNew.NumGPU = -1
timeout := 10 * time.Second timeout := 10 * time.Second
if runner.loading { if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems... timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
} }
ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
// Don't reload runner if num_gpu=-1 was provided
optsExisting := runner.Options.Runner
optsNew := req.opts.Runner
if optsNew.NumGPU < 0 {
optsExisting.NumGPU = -1
optsNew.NumGPU = -1
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel() defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed? if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed? !reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
@ -438,6 +443,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
runner.llama.Ping(ctx) != nil { runner.llama.Ping(ctx) != nil {
return true return true
} }
return false return false
} }

View file

@ -490,6 +490,9 @@ func TestNeedsReload(t *testing.T) {
require.False(t, resp) require.False(t, resp)
req.opts.NumGPU = 99 req.opts.NumGPU = 99
resp = runner.needsReload(ctx, req) resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumGPU = -1
resp = runner.needsReload(ctx, req)
require.False(t, resp) require.False(t, resp)
} }