fix model reloading

ensure runtime model changes (template, system prompt, messages,
options) are captured on model updates without needing to reload the
server
This commit is contained in:
Michael Yang 2024-07-03 09:00:07 -07:00
parent 2c3fe1fd97
commit ac7a842e55
2 changed files with 23 additions and 21 deletions

View file

@ -679,7 +679,7 @@ type CompletionRequest struct {
Prompt string Prompt string
Format string Format string
Images []ImageData Images []ImageData
Options api.Options Options *api.Options
} }
type CompletionResponse struct { type CompletionResponse struct {

View file

@ -69,23 +69,25 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil return opts, nil
} }
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (*runnerRef, error) { // scheduleRunner schedules a runner after validating inputs such as capabilities and model options.
// It returns the allocated runner, model instance, and consolidated options if successful and error otherwise.
func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) {
if name == "" { if name == "" {
return nil, fmt.Errorf("model %w", errRequired) return nil, nil, nil, fmt.Errorf("model %w", errRequired)
} }
model, err := GetModel(name) model, err := GetModel(name)
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
if err := model.CheckCapabilities(caps...); err != nil { if err := model.CheckCapabilities(caps...); err != nil {
return nil, fmt.Errorf("%s %w", name, err) return nil, nil, nil, fmt.Errorf("%s %w", name, err)
} }
opts, err := modelOptions(model, requestOpts) opts, err := modelOptions(model, requestOpts)
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive) runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
@ -93,10 +95,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
select { select {
case runner = <-runnerCh: case runner = <-runnerCh:
case err = <-errCh: case err = <-errCh:
return nil, err return nil, nil, nil, err
} }
return runner, nil return runner.llama, model, &opts, nil
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
@ -118,7 +120,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return return
@ -147,8 +149,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var msgs []api.Message var msgs []api.Message
if req.System != "" { if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System}) msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if r.model.System != "" { } else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: r.model.System}) msgs = append(msgs, api.Message{Role: "system", Content: m.System})
} }
for _, i := range images { for _, i := range images {
@ -157,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := r.model.Template tmpl := m.Template
if req.Template != "" { if req.Template != "" {
tmpl, err = template.Parse(req.Template) tmpl, err = template.Parse(req.Template)
if err != nil { if err != nil {
@ -168,7 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var b bytes.Buffer var b bytes.Buffer
if req.Context != nil { if req.Context != nil {
s, err := r.llama.Detokenize(c.Request.Context(), req.Context) s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -190,11 +192,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: *r.Options, Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
ch <- api.GenerateResponse{ ch <- api.GenerateResponse{
Model: req.Model, Model: req.Model,
@ -254,7 +256,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
r, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) handleScheduleError(c, req.Model, err)
return return
@ -266,7 +268,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, err := r.llama.Embedding(c.Request.Context(), req.Prompt) embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@ -1130,7 +1132,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
r, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return return
@ -1150,7 +1152,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
prompt, images, err := chatPrompt(c.Request.Context(), r.model, r.llama.Tokenize, r.Options, req.Messages) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -1161,11 +1163,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
if err := r.llama.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: *r.Options, Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
ch <- api.ChatResponse{ ch <- api.ChatResponse{
Model: req.Model, Model: req.Model,