app: gracefully shut down ollama serve on windows (#3641)

* app: gracefully shut down `ollama serve` on windows

* fix linter errors

* bring back `HideWindow`

* remove creation flags

* restore `windows.CREATE_NEW_PROCESS_GROUP`
This commit is contained in:
Jeffrey Morgan 2024-04-14 18:33:25 -04:00 committed by GitHub
parent 9bee3b63b1
commit 7027f264fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 118 additions and 7 deletions

View file

@ -9,7 +9,6 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"syscall"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@ -87,19 +86,29 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
// Re-wire context done behavior to attempt a graceful shutdown of the server // Re-wire context done behavior to attempt a graceful shutdown of the server
cmd.Cancel = func() error { cmd.Cancel = func() error {
if cmd.Process != nil { if cmd.Process != nil {
cmd.Process.Signal(os.Interrupt) //nolint:errcheck err := terminate(cmd)
if err != nil {
slog.Warn("error trying to gracefully terminate server", "err", err)
return cmd.Process.Kill()
}
tick := time.NewTicker(10 * time.Millisecond) tick := time.NewTicker(10 * time.Millisecond)
defer tick.Stop() defer tick.Stop()
for { for {
select { select {
case <-tick.C: case <-tick.C:
// OS agnostic "is it still running" exited, err := isProcessExited(cmd.Process.Pid)
if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { if err != nil {
return nil //nolint:nilerr return err
}
if exited {
return nil
} }
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid) slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
cmd.Process.Kill() //nolint:errcheck return cmd.Process.Kill()
} }
} }
} }

View file

@ -4,9 +4,35 @@ package lifecycle
import ( import (
"context" "context"
"errors"
"fmt"
"os"
"os/exec" "os/exec"
"syscall"
) )
func getCmd(ctx context.Context, cmd string) *exec.Cmd { func getCmd(ctx context.Context, cmd string) *exec.Cmd {
return exec.CommandContext(ctx, cmd, "serve") return exec.CommandContext(ctx, cmd, "serve")
} }
func terminate(cmd *exec.Cmd) error {
return cmd.Process.Signal(os.Interrupt)
}
func isProcessExited(pid int) (bool, error) {
proc, err := os.FindProcess(pid)
if err != nil {
return false, fmt.Errorf("failed to find process: %v", err)
}
err = proc.Signal(syscall.Signal(0))
if err != nil {
if errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH) {
return true, nil
}
return false, fmt.Errorf("error signaling process: %v", err)
}
return false, nil
}

View file

@ -2,12 +2,88 @@ package lifecycle
import ( import (
"context" "context"
"fmt"
"os/exec" "os/exec"
"syscall" "syscall"
"golang.org/x/sys/windows"
) )
func getCmd(ctx context.Context, exePath string) *exec.Cmd { func getCmd(ctx context.Context, exePath string) *exec.Cmd {
cmd := exec.CommandContext(ctx, exePath, "serve") cmd := exec.CommandContext(ctx, exePath, "serve")
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000} cmd.SysProcAttr = &syscall.SysProcAttr{
HideWindow: true,
CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
}
return cmd return cmd
} }
func terminate(cmd *exec.Cmd) error {
dll, err := windows.LoadDLL("kernel32.dll")
if err != nil {
return err
}
defer dll.Release() // nolint: errcheck
pid := cmd.Process.Pid
f, err := dll.FindProc("AttachConsole")
if err != nil {
return err
}
r1, _, err := f.Call(uintptr(pid))
if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED {
return err
}
f, err = dll.FindProc("SetConsoleCtrlHandler")
if err != nil {
return err
}
r1, _, err = f.Call(0, 1)
if r1 == 0 {
return err
}
f, err = dll.FindProc("GenerateConsoleCtrlEvent")
if err != nil {
return err
}
r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid))
if r1 == 0 {
return err
}
r1, _, err = f.Call(windows.CTRL_C_EVENT, uintptr(pid))
if r1 == 0 {
return err
}
return nil
}
const STILL_ACTIVE = 259
func isProcessExited(pid int) (bool, error) {
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid))
if err != nil {
return false, fmt.Errorf("failed to open process: %v", err)
}
defer windows.CloseHandle(hProcess) // nolint: errcheck
var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode)
if err != nil {
return false, fmt.Errorf("failed to get exit code: %v", err)
}
if exitCode == STILL_ACTIVE {
return false, nil
}
return true, nil
}