From 0b3118e0afe1a4658264081979b04aab9fda82d6 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Wed, 3 Jan 2024 12:01:42 -0500
Subject: [PATCH] fix: relay request opts to loaded llm prediction (#1761)

---
 llm/ext_server_common.go  |  36 +++++------
 llm/ext_server_default.go |   2 +-
 llm/llama.go              |   7 +-
 llm/shim_ext_server.go    |   2 +-
 server/routes.go          | 130 ++++++++++++++++++++++++--------------
 5 files changed, 106 insertions(+), 71 deletions(-)

diff --git a/llm/ext_server_common.go b/llm/ext_server_common.go
index 8e5f34e3..470df412 100644
--- a/llm/ext_server_common.go
+++ b/llm/ext_server_common.go
@@ -153,7 +153,7 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
 	return server, nil
 }
 
-func predict(llm extServer, opts api.Options, ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
+func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
 	var imageData []ImageData
@@ -167,23 +167,23 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
 	request := map[string]any{
 		"prompt":            predict.Prompt,
 		"stream":            true,
-		"n_predict":         opts.NumPredict,
-		"n_keep":            opts.NumKeep,
-		"temperature":       opts.Temperature,
-		"top_k":             opts.TopK,
-		"top_p":             opts.TopP,
-		"tfs_z":             opts.TFSZ,
-		"typical_p":         opts.TypicalP,
-		"repeat_last_n":     opts.RepeatLastN,
-		"repeat_penalty":    opts.RepeatPenalty,
-		"presence_penalty":  opts.PresencePenalty,
-		"frequency_penalty": opts.FrequencyPenalty,
-		"mirostat":          opts.Mirostat,
-		"mirostat_tau":      opts.MirostatTau,
-		"mirostat_eta":      opts.MirostatEta,
-		"penalize_nl":       opts.PenalizeNewline,
-		"seed":              opts.Seed,
-		"stop":              opts.Stop,
+		"n_predict":         predict.Options.NumPredict,
+		"n_keep":            predict.Options.NumKeep,
+		"temperature":       predict.Options.Temperature,
+		"top_k":             predict.Options.TopK,
+		"top_p":             predict.Options.TopP,
+		"tfs_z":             predict.Options.TFSZ,
+		"typical_p":         predict.Options.TypicalP,
+		"repeat_last_n":     predict.Options.RepeatLastN,
+		"repeat_penalty":    predict.Options.RepeatPenalty,
+		"presence_penalty":  predict.Options.PresencePenalty,
+		"frequency_penalty": predict.Options.FrequencyPenalty,
+		"mirostat":          predict.Options.Mirostat,
+		"mirostat_tau":      predict.Options.MirostatTau,
+		"mirostat_eta":      predict.Options.MirostatEta,
+		"penalize_nl":       predict.Options.PenalizeNewline,
+		"seed":              predict.Options.Seed,
+		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
 		"cache_prompt":      true,
 	}
diff --git a/llm/ext_server_default.go b/llm/ext_server_default.go
index 797eb851..80e00081 100644
--- a/llm/ext_server_default.go
+++ b/llm/ext_server_default.go
@@ -60,7 +60,7 @@ func newDefaultExtServer(model string, adapters, projectors []string, numLayers
 }
 
 func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.Options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
diff --git a/llm/llama.go b/llm/llama.go
index f8e22960..89616cf0 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -166,9 +166,10 @@ const maxRetries = 3
 const retryDelay = 1 * time.Second
 
 type PredictOpts struct {
-	Prompt string
-	Format string
-	Images []api.ImageData
+	Prompt  string
+	Format  string
+	Images  []api.ImageData
+	Options api.Options
 }
 
 type PredictResult struct {
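Note: with the Options field added to PredictOpts above, every prediction now
carries its own api.Options instead of reusing whatever the runner captured at
load time. A minimal caller sketch, assuming the jmorganca/ollama module path
of this era, an llm.LLM interface exposing Predict, and a PredictResult that
carries the generated text in a Content field (runner and generate are
illustrative names, not part of this patch):

    package main

    import (
    	"context"
    	"fmt"

    	"github.com/jmorganca/ollama/api"
    	"github.com/jmorganca/ollama/llm"
    )

    // generate relays request-scoped options alongside the prompt, so two
    // back-to-back calls with different options behave differently.
    func generate(ctx context.Context, runner llm.LLM, prompt string) error {
    	opts := api.DefaultOptions()
    	opts.Temperature = 0.2 // honored per call after this patch
    	opts.NumPredict = 64

    	return runner.Predict(ctx, llm.PredictOpts{
    		Prompt:  prompt,
    		Options: opts,
    	}, func(r llm.PredictResult) {
    		fmt.Print(r.Content) // assumed field, see note above
    	})
    }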
diff --git a/llm/shim_ext_server.go b/llm/shim_ext_server.go
index 146456d3..96192731 100644
--- a/llm/shim_ext_server.go
+++ b/llm/shim_ext_server.go
@@ -92,7 +92,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
 }
 
 func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
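Note: the server half of the change follows. The new modelOptions helper in
server/routes.go layers options by precedence: API defaults first, then the
model's Modelfile options, then per-request options, with each FromMap call
overwriting the keys it receives. A small sketch of that layering, under the
assumption that FromMap only touches keys present in its map; the 0.8 default
shown is illustrative:

    package main

    import "github.com/jmorganca/ollama/api"

    // resolveExample mirrors the shape of modelOptions: later FromMap calls
    // win, so a request-level temperature overrides the Modelfile value.
    func resolveExample() (api.Options, error) {
    	opts := api.DefaultOptions() // e.g. Temperature 0.8

    	// Modelfile-level option
    	if err := opts.FromMap(map[string]interface{}{"temperature": 0.7}); err != nil {
    		return api.Options{}, err
    	}

    	// request-level option, applied last
    	if err := opts.FromMap(map[string]interface{}{"temperature": 0.1}); err != nil {
    		return api.Options{}, err
    	}

    	return opts, nil // opts.Temperature == 0.1
    }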
diff --git a/server/routes.go b/server/routes.go
index 123f964b..d4485afc 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -64,24 +64,9 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
-	model, err := GetModel(modelName)
-	if err != nil {
-		return nil, err
-	}
-
+func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
 	workDir := c.GetString("workDir")
 
-	opts := api.DefaultOptions()
-	if err := opts.FromMap(model.Options); err != nil {
-		log.Printf("could not load model options: %v", err)
-		return nil, err
-	}
-
-	if err := opts.FromMap(reqOpts); err != nil {
-		return nil, err
-	}
-
 	needLoad := loaded.runner == nil || // is there a model loaded?
 		loaded.ModelPath != model.ModelPath || // has the base model changed?
 		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
@@ -105,7 +90,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 			err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
 		}
 
-		return nil, err
+		return err
 	}
 
 	loaded.Model = model
@@ -135,7 +120,20 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 	}
 
 	loaded.expireTimer.Reset(sessionDuration)
-	return model, nil
+	return nil
+}
+
+func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
+	opts := api.DefaultOptions()
+	if err := opts.FromMap(model.Options); err != nil {
+		return api.Options{}, err
+	}
+
+	if err := opts.FromMap(requestOpts); err != nil {
+		return api.Options{}, err
+	}
+
+	return opts, nil
 }
 
 func GenerateHandler(c *gin.Context) {
@@ -168,18 +166,30 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -287,9 +297,10 @@ func GenerateHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt: prompt,
-			Format: req.Format,
-			Images: req.Images,
+			Prompt:  prompt,
+			Format:  req.Format,
+			Images:  req.Images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -347,18 +358,29 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	_, err = load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -991,18 +1013,29 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -1053,9 +1086,10 @@ func ChatHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt: prompt,
-			Format: req.Format,
-			Images: images,
+			Prompt:  prompt,
+			Format:  req.Format,
+			Images:  images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
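Note: end to end, all three handlers (generate, embeddings, chat) now share the
same shape. A condensed sketch of the flow this patch establishes in
server/routes.go, with HTTP error mapping and streaming elided; handleGenerate
and fn are illustrative wrappers, while GetModel, modelOptions, load, and
loaded.runner.Predict are the names from the diff:

    // handleGenerate resolves options per request and relays them to the
    // loaded runner; dropping them on the second and later requests was
    // the bug this patch fixes.
    func handleGenerate(c *gin.Context, req api.GenerateRequest, prompt string, fn func(llm.PredictResult)) error {
    	model, err := GetModel(req.Model) // lookup is no longer buried inside load()
    	if err != nil {
    		return err
    	}

    	opts, err := modelOptions(model, req.Options) // defaults <- Modelfile <- request
    	if err != nil {
    		return err
    	}

    	// load reloads the runner only if the model or its options changed
    	if err := load(c, model, opts, defaultSessionDuration); err != nil {
    		return err
    	}

    	return loaded.runner.Predict(c.Request.Context(), llm.PredictOpts{
    		Prompt:  prompt,
    		Format:  req.Format,
    		Images:  req.Images,
    		Options: opts,
    	}, fn)
    }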