diff --git a/server/routes.go b/server/routes.go index 90cdfcd5..130423b7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1036,7 +1036,8 @@ func Serve(ln net.Listener) error { } ctx, done := context.WithCancel(context.Background()) - sched := InitScheduler(ctx) + schedCtx, schedDone := context.WithCancel(ctx) + sched := InitScheduler(schedCtx) s := &Server{addr: ln.Addr(), sched: sched} r := s.GenerateRoutes() @@ -1051,24 +1052,31 @@ func Serve(ln net.Listener) error { go func() { <-signals srvr.Close() - done() + schedDone() sched.unloadAllRunners() gpu.Cleanup() - os.Exit(0) + done() }() if err := llm.Init(); err != nil { return fmt.Errorf("unable to initialize llm library %w", err) } - s.sched.Run(ctx) + s.sched.Run(schedCtx) // At startup we retrieve GPU information so we can get log messages before loading a model // This will log warnings to the log in case we have problems with detected GPUs gpus := gpu.GetGPUInfo() gpus.LogDetails() - return srvr.Serve(ln) + err = srvr.Serve(ln) + // If server is closed from the signal handler, wait for the ctx to be done + // otherwise error out quickly + if !errors.Is(err, http.ErrServerClosed) { + return err + } + <-ctx.Done() + return err } func waitForStream(c *gin.Context, ch chan interface{}) {