diff --git a/app/lifecycle/server.go b/app/lifecycle/server.go index 0ce90df9..8680e7bc 100644 --- a/app/lifecycle/server.go +++ b/app/lifecycle/server.go @@ -9,7 +9,6 @@ import ( "os" "os/exec" "path/filepath" - "syscall" "time" "github.com/ollama/ollama/api" @@ -87,19 +86,29 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) { // Re-wire context done behavior to attempt a graceful shutdown of the server cmd.Cancel = func() error { if cmd.Process != nil { - cmd.Process.Signal(os.Interrupt) //nolint:errcheck + err := terminate(cmd) + if err != nil { + slog.Warn("error trying to gracefully terminate server", "err", err) + return cmd.Process.Kill() + } + tick := time.NewTicker(10 * time.Millisecond) defer tick.Stop() + for { select { case <-tick.C: - // OS agnostic "is it still running" - if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { - return nil //nolint:nilerr + exited, err := isProcessExited(cmd.Process.Pid) + if err != nil { + return err + } + + if exited { + return nil } case <-time.After(5 * time.Second): slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid) - cmd.Process.Kill() //nolint:errcheck + return cmd.Process.Kill() } } } diff --git a/app/lifecycle/server_unix.go b/app/lifecycle/server_unix.go index c35f8b5b..70573913 100644 --- a/app/lifecycle/server_unix.go +++ b/app/lifecycle/server_unix.go @@ -4,9 +4,35 @@ package lifecycle import ( "context" + "errors" + "fmt" + "os" "os/exec" + "syscall" ) func getCmd(ctx context.Context, cmd string) *exec.Cmd { return exec.CommandContext(ctx, cmd, "serve") } + +func terminate(cmd *exec.Cmd) error { + return cmd.Process.Signal(os.Interrupt) +} + +func isProcessExited(pid int) (bool, error) { + proc, err := os.FindProcess(pid) + if err != nil { + return false, fmt.Errorf("failed to find process: %v", err) + } + + err = proc.Signal(syscall.Signal(0)) + if err != nil { + if errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH) { + return true, nil + } + + return false, fmt.Errorf("error signaling process: %v", err) + } + + return false, nil +} diff --git a/app/lifecycle/server_windows.go b/app/lifecycle/server_windows.go index 3044e526..cd4244ff 100644 --- a/app/lifecycle/server_windows.go +++ b/app/lifecycle/server_windows.go @@ -2,12 +2,88 @@ package lifecycle import ( "context" + "fmt" "os/exec" "syscall" + + "golang.org/x/sys/windows" ) func getCmd(ctx context.Context, exePath string) *exec.Cmd { cmd := exec.CommandContext(ctx, exePath, "serve") - cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000} + cmd.SysProcAttr = &syscall.SysProcAttr{ + HideWindow: true, + CreationFlags: windows.CREATE_NEW_PROCESS_GROUP, + } + return cmd } + +func terminate(cmd *exec.Cmd) error { + dll, err := windows.LoadDLL("kernel32.dll") + if err != nil { + return err + } + defer dll.Release() // nolint: errcheck + + pid := cmd.Process.Pid + + f, err := dll.FindProc("AttachConsole") + if err != nil { + return err + } + + r1, _, err := f.Call(uintptr(pid)) + if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED { + return err + } + + f, err = dll.FindProc("SetConsoleCtrlHandler") + if err != nil { + return err + } + + r1, _, err = f.Call(0, 1) + if r1 == 0 { + return err + } + + f, err = dll.FindProc("GenerateConsoleCtrlEvent") + if err != nil { + return err + } + + r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid)) + if r1 == 0 { + return err + } + + r1, _, err = f.Call(windows.CTRL_C_EVENT, uintptr(pid)) + if r1 == 0 { + return err + } + + return nil +} + +const STILL_ACTIVE = 259 + +func isProcessExited(pid int) (bool, error) { + hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) + if err != nil { + return false, fmt.Errorf("failed to open process: %v", err) + } + defer windows.CloseHandle(hProcess) // nolint: errcheck + + var exitCode uint32 + err = windows.GetExitCodeProcess(hProcess, &exitCode) + if err != nil { + return false, fmt.Errorf("failed to get exit code: %v", err) + } + + if exitCode == STILL_ACTIVE { + return false, nil + } + + return true, nil +}