From ac7a842e550721fbc00e36e416e7cf6606993149 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 3 Jul 2024 09:00:07 -0700
Subject: [PATCH] fix model reloading

ensure runtime model changes (template, system prompt, messages, options)
are captured on model updates without needing to reload the server
---
 llm/server.go    |  2 +-
 server/routes.go | 42 ++++++++++++++++++++++--------------------
 2 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/llm/server.go b/llm/server.go
index 206f9e39..229d61e4 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -679,7 +679,7 @@ type CompletionRequest struct {
 	Prompt  string
 	Format  string
 	Images  []ImageData
-	Options api.Options
+	Options *api.Options
 }
 
 type CompletionResponse struct {
diff --git a/server/routes.go b/server/routes.go
index 1a93e977..4059c7c5 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
 	return opts, nil
 }
 
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) {
+// scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
+// It returns the allocated runner, model instance, and consolidated options if successful, and an error otherwise.
+func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
 	if name == "" {
-		return nil, fmt.Errorf("model %w", errRequired)
+		return nil, nil, nil, fmt.Errorf("model %w", errRequired)
 	}
 
 	model, err := GetModel(name)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	if err := model.CheckCapabilities(caps...); err != nil {
-		return nil, fmt.Errorf("%s %w", name, err)
+		return nil, nil, nil, fmt.Errorf("%s %w", name, err)
 	}
 
 	opts, err := modelOptions(model, requestOpts)
 	if err != nil {
-		return nil, err
+		return nil, nil, nil, err
 	}
 
 	runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
@@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
 	select {
 	case runner = <-runnerCh:
 	case err = <-errCh:
-		return nil, err
+		return nil, nil, nil, err
 	}
 
-	return runner, nil
+	return runner.llama, model, &opts, nil
 }
 
 func (s *Server) GenerateHandler(c *gin.Context) {
@@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
 		return
@@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	var msgs []api.Message
 	if req.System != "" {
 		msgs = append(msgs, api.Message{Role: "system", Content: req.System})
-	} else if r.model.System != "" {
-		msgs = append(msgs, api.Message{Role: "system", Content: r.model.System})
+	} else if m.System != "" {
+		msgs = append(msgs, api.Message{Role: "system", Content: m.System})
 	}
 
 	for _, i := range images {
@@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 
-	tmpl := r.model.Template
+	tmpl := m.Template
 	if req.Template != "" {
 		tmpl, err = template.Parse(req.Template)
 		if err != nil {
@@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	var b bytes.Buffer
 	if req.Context != nil {
-		s, err := r.llama.Detokenize(c.Request.Context(), req.Context)
+		s, err := r.Detokenize(c.Request.Context(), req.Context)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.GenerateResponse{
 				Model: req.Model,
@@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
+	r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
 	if err != nil {
 		handleScheduleError(c, req.Model, err)
 		return
@@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
+	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
 		return
@@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		return
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{
+		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: *r.Options,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			ch <- api.ChatResponse{
 				Model: req.Model,
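
Reviewer note (not part of the patch): the sketch below only illustrates the new scheduleRunner contract for anyone updating call sites. exampleHandler, the model name, and the prompt are hypothetical placeholders; scheduleRunner, Capability, llm.CompletionRequest, and handleScheduleError are taken from the diff above.

	// Hypothetical caller of the new four-value scheduleRunner signature.
	func (s *Server) exampleHandler(c *gin.Context) {
		// r is the llm.LlamaServer used for inference; m is the *Model fetched
		// fresh via GetModel, so its Template and System fields reflect any
		// edits made since the runner was loaded; opts is the consolidated
		// *api.Options. This is what makes runtime model changes visible
		// without reloading the server.
		r, m, opts, err := s.scheduleRunner(c.Request.Context(), "example-model", []Capability{CapabilityCompletion}, nil, nil)
		if err != nil {
			handleScheduleError(c, "example-model", err)
			return
		}

		// Prompt construction would read m.Template and m.System here, rather
		// than reaching through the cached runner reference as before.
		_ = m

		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
			Prompt:  "Hello",
			Options: opts, // CompletionRequest.Options is now *api.Options
		}, func(cr llm.CompletionResponse) {
			// stream cr back to the client
		}); err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
	}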