fix: relay request opts to loaded llm prediction (#1761)

Bruce MacDonald 2024-01-03 12:01:42 -05:00 committed by GitHub
parent 05face44ef
commit 0b3118e0af
5 changed files with 106 additions and 71 deletions
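
In short: before this change, sampling options were merged and applied only when a model was loaded, so request-level option overrides were silently ignored whenever the model was already resident. The fix carries api.Options inside llm.PredictOpts, so every prediction call relays the request's options to the loaded runner. The central signature change:

// before: options fixed at load time
func predict(llm extServer, opts api.Options, ctx context.Context, predict PredictOpts, fn func(PredictResult)) error

// after: options travel inside the per-request PredictOpts
func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error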


@@ -153,7 +153,7 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
 	return server, nil
 }
 
-func predict(llm extServer, opts api.Options, ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
+func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
 	var imageData []ImageData
@@ -167,23 +167,23 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
 	request := map[string]any{
 		"prompt":            predict.Prompt,
 		"stream":            true,
-		"n_predict":         opts.NumPredict,
-		"n_keep":            opts.NumKeep,
-		"temperature":       opts.Temperature,
-		"top_k":             opts.TopK,
-		"top_p":             opts.TopP,
-		"tfs_z":             opts.TFSZ,
-		"typical_p":         opts.TypicalP,
-		"repeat_last_n":     opts.RepeatLastN,
-		"repeat_penalty":    opts.RepeatPenalty,
-		"presence_penalty":  opts.PresencePenalty,
-		"frequency_penalty": opts.FrequencyPenalty,
-		"mirostat":          opts.Mirostat,
-		"mirostat_tau":      opts.MirostatTau,
-		"mirostat_eta":      opts.MirostatEta,
-		"penalize_nl":       opts.PenalizeNewline,
-		"seed":              opts.Seed,
-		"stop":              opts.Stop,
+		"n_predict":         predict.Options.NumPredict,
+		"n_keep":            predict.Options.NumKeep,
+		"temperature":       predict.Options.Temperature,
+		"top_k":             predict.Options.TopK,
+		"top_p":             predict.Options.TopP,
+		"tfs_z":             predict.Options.TFSZ,
+		"typical_p":         predict.Options.TypicalP,
+		"repeat_last_n":     predict.Options.RepeatLastN,
+		"repeat_penalty":    predict.Options.RepeatPenalty,
+		"presence_penalty":  predict.Options.PresencePenalty,
+		"frequency_penalty": predict.Options.FrequencyPenalty,
+		"mirostat":          predict.Options.Mirostat,
+		"mirostat_tau":      predict.Options.MirostatTau,
+		"mirostat_eta":      predict.Options.MirostatEta,
+		"penalize_nl":       predict.Options.PenalizeNewline,
+		"seed":              predict.Options.Seed,
+		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
 		"cache_prompt":      true,
 	}
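
All of the sampling parameters above now read from predict.Options, built fresh for each request, instead of from the options the server was started with. A minimal sketch of what a caller provides (the prompt and temperature values are illustrative):

opts := api.DefaultOptions()
opts.Temperature = 0.1 // per-request override; with the old code this could be ignored

pred := PredictOpts{
	Prompt:  "why is the sky blue?",
	Options: opts, // serialized into the request map above
}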


@@ -60,7 +60,7 @@ func newDefaultExtServer(model string, adapters, projectors []string, numLayers
 }
 
 func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.Options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {


@@ -166,9 +166,10 @@ const maxRetries = 3
 const retryDelay = 1 * time.Second
 
 type PredictOpts struct {
-	Prompt string
-	Format string
-	Images []api.ImageData
+	Prompt  string
+	Format  string
+	Images  []api.ImageData
+	Options api.Options
 }
 
 type PredictResult struct {

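A hedged usage sketch of the extended struct, mirroring the handler call sites below (PredictResult's fields are not part of this diff, so the Content field here is an assumption):

predictReq := llm.PredictOpts{
	Prompt:  prompt,
	Format:  req.Format,
	Images:  req.Images,
	Options: opts, // produced by modelOptions in the server changes below
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, func(r llm.PredictResult) {
	fmt.Print(r.Content) // assumption: Content carries the generated text
}); err != nil {
	// handle the prediction error
}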

@@ -92,7 +92,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
 }
 
 func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {


@@ -64,24 +64,9 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
-	model, err := GetModel(modelName)
-	if err != nil {
-		return nil, err
-	}
-
+func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
 	workDir := c.GetString("workDir")
 
-	opts := api.DefaultOptions()
-	if err := opts.FromMap(model.Options); err != nil {
-		log.Printf("could not load model options: %v", err)
-		return nil, err
-	}
-
-	if err := opts.FromMap(reqOpts); err != nil {
-		return nil, err
-	}
-
 	needLoad := loaded.runner == nil || // is there a model loaded?
 		loaded.ModelPath != model.ModelPath || // has the base model changed?
 		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
@@ -105,7 +90,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 			err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
 		}
 
-		return nil, err
+		return err
 	}
 
 	loaded.Model = model
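
load is correspondingly narrowed: it no longer resolves the model or parses options, and it reports only an error. Side by side:

// before: load did everything and returned the model
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error)

// after: the caller supplies a resolved *Model and merged api.Options
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error
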
@@ -135,7 +120,20 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 	}
 
 	loaded.expireTimer.Reset(sessionDuration)
-	return model, nil
+	return nil
+}
+
+func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
+	opts := api.DefaultOptions()
+	if err := opts.FromMap(model.Options); err != nil {
+		return api.Options{}, err
+	}
+
+	if err := opts.FromMap(requestOpts); err != nil {
+		return api.Options{}, err
+	}
+
+	return opts, nil
 }
 
 func GenerateHandler(c *gin.Context) {
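
The extracted modelOptions helper applies its three option sources in increasing precedence: api.DefaultOptions(), then the model's own options from its Modelfile, then the request's options, since each FromMap call overwrites earlier values. For example (values illustrative):

// Suppose the Modelfile sets "temperature": 0.6 and the request sends {"temperature": 0}.
opts, err := modelOptions(model, map[string]interface{}{"temperature": 0})
// err == nil and opts.Temperature == 0: the request value is applied last and wins.
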
@@ -168,18 +166,30 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
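
The handler now performs three steps with distinct failure modes, and invalid options are rejected before any model is loaded into memory:

// The three steps and their HTTP mappings in the new flow:
//   GetModel(req.Model)                   -> 404 on *fs.PathError, else 500
//   modelOptions(model, req.Options)      -> 400 on api.ErrInvalidOpts, else 500
//   load(c, model, opts, sessionDuration) -> 500 on any error
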
@@ -287,9 +297,10 @@ func GenerateHandler(c *gin.Context) {
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt: prompt,
-			Format: req.Format,
-			Images: req.Images,
+			Prompt:  prompt,
+			Format:  req.Format,
+			Images:  req.Images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -347,18 +358,29 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	_, err = load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
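
Note the incidental fix here: EmbeddingHandler previously discarded load's returned model, so it had no *Model to consult; resolving it with GetModel first lets the same modelOptions path validate embedding requests too. Schematically:

// before: the returned model was thrown away
_, err = load(c, req.Model, req.Options, sessionDuration)

// after: resolve once, reuse for option merging, then load
model, err := GetModel(req.Model)
opts, err := modelOptions(model, req.Options)
err = load(c, model, opts, sessionDuration)
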
@@ -991,18 +1013,29 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -1053,9 +1086,10 @@ func ChatHandler(c *gin.Context) {
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt: prompt,
-			Format: req.Format,
-			Images: images,
+			Prompt:  prompt,
+			Format:  req.Format,
+			Images:  images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}