Terminate subprocess if receiving SIGINT or SIGTERM signals while model is loading (#3653)

* terminate subprocess if receiving `SIGINT` or `SIGTERM` signals while model is loading

* use `unload` in signal handler
This commit is contained in:
Jeffrey Morgan 2024-04-15 12:09:32 -04:00 committed by GitHub
parent 7027f264fb
commit a0b8a32eb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 23 additions and 32 deletions

View file

@ -17,7 +17,6 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime" "runtime"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -36,10 +35,6 @@ type LlamaServer struct {
options api.Options options api.Options
} }
var cpuOnlyFamilies = []string{
"mamba",
}
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) { func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
f, err := os.Open(model) f, err := os.Open(model)
if err != nil { if err != nil {
@ -91,7 +86,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
memoryRequiredPartial := memoryMinimum + graphPartialOffload memoryRequiredPartial := memoryMinimum + graphPartialOffload
if info.Library != "metal" { if info.Library != "metal" {
if memoryRequiredPartial > memoryAvailable || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture()) { if memoryRequiredPartial > memoryAvailable {
info.Library = "cpu" info.Library = "cpu"
} }
} }
@ -277,12 +272,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
_ = s.cmd.Wait() _ = s.cmd.Wait()
}() }()
if err = s.waitUntilRunning(); err != nil {
slog.Error("error starting llama server", "server", servers[i], "error", err)
s.Close()
finalErr = err
continue
}
return s, nil return s, nil
} }
@ -383,7 +372,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
return nil return nil
} }
func (s *LlamaServer) waitUntilRunning() error { func (s *LlamaServer) WaitUntilRunning() error {
start := time.Now() start := time.Now()
// TODO we need to wire up a better way to detect hangs during model load and startup of the server // TODO we need to wire up a better way to detect hangs during model load and startup of the server
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load

View file

@ -68,6 +68,18 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute var defaultSessionDuration = 5 * time.Minute
func unload() {
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error { func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second) ctx, cancel := context.WithTimeout(c, 10*time.Second)
@ -83,12 +95,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
if needLoad { if needLoad {
if loaded.llama != nil { if loaded.llama != nil {
slog.Info("changing loaded model") slog.Info("changing loaded model")
loaded.llama.Close() unload()
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
} }
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts) llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
@ -108,22 +115,19 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
loaded.projectors = model.ProjectorPaths loaded.projectors = model.ProjectorPaths
loaded.llama = llama loaded.llama = llama
loaded.Options = &opts loaded.Options = &opts
if err = llama.WaitUntilRunning(); err != nil {
slog.Error("error loading llama server", "error", err)
unload()
return err
}
} }
if loaded.expireTimer == nil { if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() { loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock() loaded.mu.Lock()
defer loaded.mu.Unlock() defer loaded.mu.Unlock()
unload()
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}) })
} }
@ -1146,9 +1150,7 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() { go func() {
<-signals <-signals
if loaded.llama != nil { unload()
loaded.llama.Close()
}
gpu.Cleanup() gpu.Cleanup()
os.Exit(0) os.Exit(0)
}() }()