From fe6f3b48f74e06b6e2e5377adcb622602f2acc8f Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 19 Oct 2023 10:39:58 -0400 Subject: [PATCH] do not reload the running llm when runtime params change (#840) - only reload the running llm if the model has changed, or the options for loading the running model have changed - rename loaded llm to runner to differentiate from loaded model image - remove logic which keeps the first system prompt in the generation context --- api/types.go | 52 ++++++++++++++-------------- server/images.go | 12 +++---- server/routes.go | 88 +++++++++++++++++++----------------------------- 3 files changed, 66 insertions(+), 86 deletions(-) diff --git a/api/types.go b/api/types.go index 8ea7a425..b0bd5d93 100644 --- a/api/types.go +++ b/api/types.go @@ -161,15 +161,10 @@ func (r *GenerateResponse) Summary() { } } -type Options struct { - Seed int `json:"seed,omitempty"` - - // Backend options - UseNUMA bool `json:"numa,omitempty"` - - // Model options +// Runner options which must be set when the model is loaded into memory +type Runner struct { + UseNUMA bool `json:"numa,omitempty"` NumCtx int `json:"num_ctx,omitempty"` - NumKeep int `json:"num_keep,omitempty"` NumBatch int `json:"num_batch,omitempty"` NumGQA int `json:"num_gqa,omitempty"` NumGPU int `json:"num_gpu,omitempty"` @@ -183,8 +178,15 @@ type Options struct { EmbeddingOnly bool `json:"embedding_only,omitempty"` RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"` RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"` + NumThread int `json:"num_thread,omitempty"` +} - // Predict options +type Options struct { + Runner + + // Predict options used at runtime + NumKeep int `json:"num_keep,omitempty"` + Seed int `json:"seed,omitempty"` NumPredict int `json:"num_predict,omitempty"` TopK int `json:"top_k,omitempty"` TopP float32 `json:"top_p,omitempty"` @@ -200,8 +202,6 @@ type Options struct { MirostatEta float32 `json:"mirostat_eta,omitempty"` PenalizeNewline bool `json:"penalize_newline,omitempty"` Stop []string `json:"stop,omitempty"` - - NumThread int `json:"num_thread,omitempty"` } var ErrInvalidOpts = fmt.Errorf("invalid options") @@ -309,20 +309,22 @@ func DefaultOptions() Options { PenalizeNewline: true, Seed: -1, - // options set when the model is loaded - NumCtx: 2048, - RopeFrequencyBase: 10000.0, - RopeFrequencyScale: 1.0, - NumBatch: 512, - NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically - NumGQA: 1, - NumThread: 0, // let the runtime decide - LowVRAM: false, - F16KV: true, - UseMLock: false, - UseMMap: true, - UseNUMA: false, - EmbeddingOnly: true, + Runner: Runner{ + // options set when the model is loaded + NumCtx: 2048, + RopeFrequencyBase: 10000.0, + RopeFrequencyScale: 1.0, + NumBatch: 512, + NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically + NumGQA: 1, + NumThread: 0, // let the runtime decide + LowVRAM: false, + F16KV: true, + UseMLock: false, + UseMMap: true, + UseNUMA: false, + EmbeddingOnly: true, + }, } } diff --git a/server/images.go b/server/images.go index 8fa9e5af..5514c643 100644 --- a/server/images.go +++ b/server/images.go @@ -45,7 +45,6 @@ type Model struct { System string License []string Digest string - ConfigDigest string Options map[string]interface{} } @@ -166,12 +165,11 @@ func GetModel(name string) (*Model, error) { } model := &Model{ - Name: mp.GetFullTagname(), - ShortName: mp.GetShortTagname(), - Digest: digest, - ConfigDigest: manifest.Config.Digest, - Template: "{{ .Prompt }}", - License: []string{}, + Name: mp.GetFullTagname(), + ShortName: mp.GetShortTagname(), + Digest: digest, + Template: "{{ .Prompt }}", + License: []string{}, } for _, layer := range manifest.Layers { diff --git a/server/routes.go b/server/routes.go index c02ca3ed..3cee381e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -46,13 +46,13 @@ func init() { var loaded struct { mu sync.Mutex - llm llm.LLM + runner llm.LLM expireAt time.Time expireTimer *time.Timer - digest string - options api.Options + *Model + *api.Options } var defaultSessionDuration = 5 * time.Minute @@ -70,59 +70,39 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string] } // check if the loaded model is still running in a subprocess, in case something unexpected happened - if loaded.llm != nil { - if err := loaded.llm.Ping(ctx); err != nil { + if loaded.runner != nil { + if err := loaded.runner.Ping(ctx); err != nil { log.Print("loaded llm process not responding, closing now") // the subprocess is no longer running, so close it - loaded.llm.Close() - loaded.llm = nil - loaded.digest = "" + loaded.runner.Close() + loaded.runner = nil + loaded.Model = nil + loaded.Options = nil } } - if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) { - if loaded.llm != nil { + needLoad := loaded.runner == nil || // is there a model loaded? + loaded.ModelPath != model.ModelPath || // has the base model changed? + !reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed? + !reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed? + + if needLoad { + if loaded.runner != nil { log.Println("changing loaded model") - loaded.llm.Close() - loaded.llm = nil - loaded.digest = "" + loaded.runner.Close() + loaded.runner = nil + loaded.Model = nil + loaded.Options = nil } - llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts) + llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts) if err != nil { return err } - // set cache values before modifying opts - loaded.llm = llmModel - loaded.digest = model.Digest - loaded.options = opts - - if opts.NumKeep < 0 { - promptWithSystem, err := model.Prompt(api.GenerateRequest{}) - if err != nil { - return err - } - - promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}) - if err != nil { - return err - } - - tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem) - if err != nil { - return err - } - - tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem) - if err != nil { - return err - } - - opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) - - llmModel.SetOptions(opts) - } + loaded.Model = model + loaded.runner = llmRunner + loaded.Options = &opts } loaded.expireAt = time.Now().Add(sessionDuration) @@ -136,13 +116,13 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string] return } - if loaded.llm == nil { - return + if loaded.runner != nil { + loaded.runner.Close() } - loaded.llm.Close() - loaded.llm = nil - loaded.digest = "" + loaded.runner = nil + loaded.Model = nil + loaded.Options = nil }) } @@ -215,7 +195,7 @@ func GenerateHandler(c *gin.Context) { if req.Prompt == "" && req.Template == "" && req.System == "" { ch <- api.GenerateResponse{Model: req.Model, Done: true} } else { - if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil { + if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil { ch <- gin.H{"error": err.Error()} } } @@ -263,12 +243,12 @@ func EmbeddingHandler(c *gin.Context) { return } - if !loaded.options.EmbeddingOnly { + if !loaded.Options.EmbeddingOnly { c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"}) return } - embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt) + embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt) if err != nil { log.Printf("embedding generation failed: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) @@ -599,8 +579,8 @@ func Serve(ln net.Listener, allowOrigins []string) error { signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) go func() { <-signals - if loaded.llm != nil { - loaded.llm.Close() + if loaded.runner != nil { + loaded.runner.Close() } os.RemoveAll(workDir) os.Exit(0)