diff --git a/llm/server.go b/llm/server.go index 707f0b8b..4c1f9634 100644 --- a/llm/server.go +++ b/llm/server.go @@ -17,7 +17,6 @@ import ( "os/exec" "path/filepath" "runtime" - "slices" "strconv" "strings" "time" @@ -36,10 +35,6 @@ type LlamaServer struct { options api.Options } -var cpuOnlyFamilies = []string{ - "mamba", -} - func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) { f, err := os.Open(model) if err != nil { @@ -91,7 +86,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option memoryRequiredPartial := memoryMinimum + graphPartialOffload if info.Library != "metal" { - if memoryRequiredPartial > memoryAvailable || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture()) { + if memoryRequiredPartial > memoryAvailable { info.Library = "cpu" } } @@ -277,12 +272,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option _ = s.cmd.Wait() }() - if err = s.waitUntilRunning(); err != nil { - slog.Error("error starting llama server", "server", servers[i], "error", err) - s.Close() - finalErr = err - continue - } return s, nil } @@ -383,7 +372,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error { return nil } -func (s *LlamaServer) waitUntilRunning() error { +func (s *LlamaServer) WaitUntilRunning() error { start := time.Now() // TODO we need to wire up a better way to detect hangs during model load and startup of the server expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load diff --git a/server/routes.go b/server/routes.go index d1e7f4cd..b0d36b14 100644 --- a/server/routes.go +++ b/server/routes.go @@ -68,6 +68,18 @@ var loaded struct { var defaultSessionDuration = 5 * time.Minute +func unload() { + if loaded.llama != nil { + loaded.llama.Close() + } + + loaded.llama = nil + loaded.model = "" + loaded.adapters = nil + loaded.projectors = nil + loaded.Options = nil +} + // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error { ctx, cancel := context.WithTimeout(c, 10*time.Second) @@ -83,12 +95,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D if needLoad { if loaded.llama != nil { slog.Info("changing loaded model") - loaded.llama.Close() - loaded.llama = nil - loaded.model = "" - loaded.adapters = nil - loaded.projectors = nil - loaded.Options = nil + unload() } llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts) @@ -108,22 +115,19 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D loaded.projectors = model.ProjectorPaths loaded.llama = llama loaded.Options = &opts + + if err = llama.WaitUntilRunning(); err != nil { + slog.Error("error loading llama server", "error", err) + unload() + return err + } } if loaded.expireTimer == nil { loaded.expireTimer = time.AfterFunc(sessionDuration, func() { loaded.mu.Lock() defer loaded.mu.Unlock() - - if loaded.llama != nil { - loaded.llama.Close() - } - - loaded.llama = nil - loaded.model = "" - loaded.adapters = nil - loaded.projectors = nil - loaded.Options = nil + unload() }) } @@ -1146,9 +1150,7 @@ func Serve(ln net.Listener) error { signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) go func() { <-signals - if loaded.llama != nil { - loaded.llama.Close() - } + unload() gpu.Cleanup() os.Exit(0) }()