add "stop" command (#6739)
This commit is contained in:
parent
034392624c
commit
abed273de3
5 changed files with 172 additions and 25 deletions
cmd/cmd.go (56 changed lines)
@@ -346,6 +346,39 @@ func (w *progressWriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
+func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
+	p := progress.NewProgress(os.Stderr)
+	defer p.StopAndClear()
+
+	spinner := progress.NewSpinner("")
+	p.Add("", spinner)
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	req := &api.GenerateRequest{
+		Model:     opts.Model,
+		KeepAlive: opts.KeepAlive,
+	}
+
+	return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
+}
+
+func StopHandler(cmd *cobra.Command, args []string) error {
+	opts := &runOptions{
+		Model:     args[0],
+		KeepAlive: &api.Duration{Duration: 0},
+	}
+	if err := loadOrUnloadModel(cmd, opts); err != nil {
+		if strings.Contains(err.Error(), "not found") {
+			return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
+		}
+	}
+	return nil
+}
+
 func RunHandler(cmd *cobra.Command, args []string) error {
 	interactive := true
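Note: StopHandler stops a model by reusing the generate path with an empty prompt and KeepAlive zero; the server-side hunks further down treat that combination as an unload request. The same call can be made from any client of the api package, roughly like this (standalone sketch; the model name is illustrative):

package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// An empty prompt with KeepAlive == 0 asks the server to expire the
	// model's runner rather than generate anything.
	req := &api.GenerateRequest{
		Model:     "llama3", // illustrative model name
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		log.Fatal(err)
	}
}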
@@ -424,7 +457,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	opts.ParentModel = info.Details.ParentModel
 
 	if interactive {
-		if err := loadModel(cmd, &opts); err != nil {
+		if err := loadOrUnloadModel(cmd, &opts); err != nil {
 			return err
 		}
 	}
@@ -615,7 +648,15 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
 			cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
 			procStr = fmt.Sprintf("%d%%/%d%% CPU/GPU", int(cpuPercent), int(100-cpuPercent))
 		}
-		data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, format.HumanTime(m.ExpiresAt, "Never")})
+
+		var until string
+		delta := time.Since(m.ExpiresAt)
+		if delta > 0 {
+			until = "Stopping..."
+		} else {
+			until = format.HumanTime(m.ExpiresAt, "Never")
+		}
+		data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
 	}
 }
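Note: because stopping a model forces its expiresAt into the past, a model that is mid-unload would otherwise show a negative expiry time in this listing; the new until branch shows "Stopping..." for that window instead.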
@@ -1294,6 +1335,15 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+
+	stopCmd := &cobra.Command{
+		Use:     "stop MODEL",
+		Short:   "Stop a running model",
+		Args:    cobra.ExactArgs(1),
+		PreRunE: checkServerHeartbeat,
+		RunE:    StopHandler,
+	}
+
 	serveCmd := &cobra.Command{
 		Use:     "serve",
 		Aliases: []string{"start"},
@@ -1361,6 +1411,7 @@ func NewCLI() *cobra.Command {
 		createCmd,
 		showCmd,
 		runCmd,
+		stopCmd,
 		pullCmd,
 		pushCmd,
 		listCmd,
@@ -1400,6 +1451,7 @@ func NewCLI() *cobra.Command {
 		createCmd,
 		showCmd,
 		runCmd,
+		stopCmd,
 		pullCmd,
 		pushCmd,
 		listCmd,
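Note: with stopCmd added to both command lists above, the CLI gains "ollama stop MODEL", for example "ollama stop llama3" (the model name is illustrative). PreRunE: checkServerHeartbeat means the command fails fast with a clear error when no server is reachable.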
cmd/interactive.go

@@ -18,7 +18,6 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/parser"
-	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
 )
@@ -31,26 +30,6 @@ const (
 	MultilineSystem
 )
 
-func loadModel(cmd *cobra.Command, opts *runOptions) error {
-	p := progress.NewProgress(os.Stderr)
-	defer p.StopAndClear()
-
-	spinner := progress.NewSpinner("")
-	p.Add("", spinner)
-
-	client, err := api.ClientFromEnvironment()
-	if err != nil {
-		return err
-	}
-
-	chatReq := &api.ChatRequest{
-		Model:     opts.Model,
-		KeepAlive: opts.KeepAlive,
-	}
-
-	return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
-}
-
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
@@ -217,7 +196,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			opts.Model = args[1]
 			opts.Messages = []api.Message{}
 			fmt.Printf("Loading model '%s'\n", opts.Model)
-			if err := loadModel(cmd, &opts); err != nil {
+			if err := loadOrUnloadModel(cmd, &opts); err != nil {
 				return err
 			}
 			continue
server/routes.go

@@ -117,6 +117,32 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
+	// expire the runner
+	if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+		model, err := GetModel(req.Model)
+		if err != nil {
+			switch {
+			case os.IsNotExist(err):
+				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+			case err.Error() == "invalid model name":
+				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			default:
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			}
+			return
+		}
+		s.sched.expireRunner(model)
+
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Response:   "",
+			Done:       true,
+			DoneReason: "unload",
+		})
+		return
+	}
+
 	if req.Format != "" && req.Format != "json" {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be empty or \"json\""})
 		return
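The unload path is also reachable over plain HTTP. Assuming the default bindings, which are outside this diff (GenerateHandler served at POST /api/generate on localhost:11434, with "model" and "keep_alive" as the JSON field names), a rough sketch:

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"
)

func main() {
	// Endpoint, port, and field names are assumptions; the route registration
	// is not part of this diff. The model name is illustrative.
	body := `{"model": "llama3", "keep_alive": 0}`
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", strings.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out)) // expect "done": true with "done_reason": "unload"
}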
@@ -1322,6 +1348,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
+	// expire the runner
+	if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
+		model, err := GetModel(req.Model)
+		if err != nil {
+			switch {
+			case os.IsNotExist(err):
+				c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+			case err.Error() == "invalid model name":
+				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			default:
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			}
+			return
+		}
+		s.sched.expireRunner(model)
+
+		c.JSON(http.StatusOK, api.ChatResponse{
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Message:    api.Message{Role: "assistant"},
+			Done:       true,
+			DoneReason: "unload",
+		})
+		return
+	}
+
 	caps := []Capability{CapabilityCompletion}
 	if len(req.Tools) > 0 {
 		caps = append(caps, CapabilityTools)
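Note: ChatHandler mirrors the generate path: an empty Messages slice with keep-alive zero expires the runner and returns an empty assistant message with DoneReason "unload", so clients calling client.Chat (as the removed loadModel did) get the same unload behavior as client.Generate.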
server/sched.go

@@ -360,7 +360,6 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
 			slog.Debug("runner expired event received", "modelPath", runner.modelPath)
 			runner.refMu.Lock()
 			if runner.refCount > 0 {
-				// Shouldn't happen, but safeguard to ensure no leaked runners
 				slog.Debug("expired event with positive ref count, retrying", "modelPath", runner.modelPath, "refCount", runner.refCount)
 				go func(runner *runnerRef) {
 					// We can't unload yet, but want to as soon as the current request completes
@@ -802,6 +801,25 @@ func (s *Scheduler) unloadAllRunners() {
 	}
 }
 
+func (s *Scheduler) expireRunner(model *Model) {
+	s.loadedMu.Lock()
+	defer s.loadedMu.Unlock()
+	runner, ok := s.loaded[model.ModelPath]
+	if ok {
+		runner.refMu.Lock()
+		runner.expiresAt = time.Now()
+		if runner.expireTimer != nil {
+			runner.expireTimer.Stop()
+			runner.expireTimer = nil
+		}
+		runner.sessionDuration = 0
+		if runner.refCount <= 0 {
+			s.expiredCh <- runner
+		}
+		runner.refMu.Unlock()
+	}
+}
+
 // If other runners are loaded, make sure the pending request will fit in system memory
 // If not, pick a runner to unload, else return nil and the request can be loaded
 func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) *runnerRef {
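Design note: expireRunner does not unload anything itself. It backdates expiresAt, cancels any pending expireTimer, and zeroes sessionDuration so the timer cannot re-arm; the runner is only pushed onto expiredCh when no requests hold a reference (refCount <= 0), and otherwise the unload happens once the in-flight requests complete. That is also why the "Shouldn't happen" comment was removed from processCompleted above: an expired event with a positive ref count is now an expected case.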
server/sched_test.go

@@ -406,6 +406,52 @@ func TestGetRunner(t *testing.T) {
 	b.ctxDone()
 }
 
+func TestExpireRunner(t *testing.T) {
+	ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
+	defer done()
+	s := InitScheduler(ctx)
+	req := &LlmRequest{
+		ctx:             ctx,
+		model:           &Model{ModelPath: "foo"},
+		opts:            api.DefaultOptions(),
+		successCh:       make(chan *runnerRef, 1),
+		errCh:           make(chan error, 1),
+		sessionDuration: &api.Duration{Duration: 2 * time.Minute},
+	}
+
+	var ggml *llm.GGML
+	gpus := gpu.GpuInfoList{}
+	server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
+	s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+		return server, nil
+	}
+	s.load(req, ggml, gpus, 0)
+
+	select {
+	case err := <-req.errCh:
+		if err != nil {
+			t.Fatalf("expected no errors when loading, got '%s'", err.Error())
+		}
+	case resp := <-req.successCh:
+		s.loadedMu.Lock()
+		if resp.refCount != uint(1) || len(s.loaded) != 1 {
+			t.Fatalf("expected a model to be loaded")
+		}
+		s.loadedMu.Unlock()
+	}
+
+	s.expireRunner(&Model{ModelPath: "foo"})
+
+	s.finishedReqCh <- req
+	s.processCompleted(ctx)
+
+	s.loadedMu.Lock()
+	if len(s.loaded) != 0 {
+		t.Fatalf("expected model to be unloaded")
+	}
+	s.loadedMu.Unlock()
+}
+
 // TODO - add one scenario that triggers the bogus finished event with positive ref count
 func TestPrematureExpired(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
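The test exercises exactly that deferred path: expireRunner is called while the request still holds a reference, so nothing is sent on expiredCh at that point; only after the request is pushed onto finishedReqCh and processCompleted drains it does the ref count reach zero, the zeroed sessionDuration expire the runner immediately, and the runner leave s.loaded within the 20ms context window.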