Reload model if num_gpu changes (#3920)

* reload model if `num_gpu` changes

* dont reload on -1

* fix tests
This commit is contained in:
Jeffrey Morgan 2024-04-25 19:02:40 -04:00 committed by GitHub
parent 993cf8bf55
commit 00b0699c75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 15 additions and 6 deletions

View file

@ -421,16 +421,21 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Ignore the NumGPU settings for comparison
optsExisting := runner.Options.Runner
optsExisting.NumGPU = -1
optsNew := req.opts.Runner
optsNew.NumGPU = -1
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
}
ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
// Don't reload runner if num_gpu=-1 was provided
optsExisting := runner.Options.Runner
optsNew := req.opts.Runner
if optsNew.NumGPU < 0 {
optsExisting.NumGPU = -1
optsNew.NumGPU = -1
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
@ -438,6 +443,7 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
runner.llama.Ping(ctx) != nil {
return true
}
return false
}

View file

@ -490,6 +490,9 @@ func TestNeedsReload(t *testing.T) {
require.False(t, resp)
req.opts.NumGPU = 99
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumGPU = -1
resp = runner.needsReload(ctx, req)
require.False(t, resp)
}