fix: relay request opts to loaded llm prediction (#1761)
parent 05face44ef
commit 0b3118e0af

5 changed files with 106 additions and 71 deletions
@@ -153,7 +153,7 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
 	return server, nil
 }
 
-func predict(llm extServer, opts api.Options, ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
+func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
 	var imageData []ImageData
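The reordering above follows the usual Go convention of taking context.Context as the first parameter, so every caller can impose cancellation or a deadline the same way. A minimal, self-contained sketch of why that ordering matters; the names here are illustrative, not from this patch:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// doWork takes ctx first, as predict now does.
func doWork(ctx context.Context, name string) error {
	select {
	case <-time.After(50 * time.Millisecond): // simulated work
		fmt.Println("done:", name)
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	if err := doWork(ctx, "predict"); err != nil {
		fmt.Println("aborted:", err) // aborted: context deadline exceeded
	}
}
```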
@@ -167,23 +167,23 @@ func predict(llm extServer, opts api.Options, ctx context.Context, predict Predi
 	request := map[string]any{
 		"prompt":            predict.Prompt,
 		"stream":            true,
-		"n_predict":         opts.NumPredict,
-		"n_keep":            opts.NumKeep,
-		"temperature":       opts.Temperature,
-		"top_k":             opts.TopK,
-		"top_p":             opts.TopP,
-		"tfs_z":             opts.TFSZ,
-		"typical_p":         opts.TypicalP,
-		"repeat_last_n":     opts.RepeatLastN,
-		"repeat_penalty":    opts.RepeatPenalty,
-		"presence_penalty":  opts.PresencePenalty,
-		"frequency_penalty": opts.FrequencyPenalty,
-		"mirostat":          opts.Mirostat,
-		"mirostat_tau":      opts.MirostatTau,
-		"mirostat_eta":      opts.MirostatEta,
-		"penalize_nl":       opts.PenalizeNewline,
-		"seed":              opts.Seed,
-		"stop":              opts.Stop,
+		"n_predict":         predict.Options.NumPredict,
+		"n_keep":            predict.Options.NumKeep,
+		"temperature":       predict.Options.Temperature,
+		"top_k":             predict.Options.TopK,
+		"top_p":             predict.Options.TopP,
+		"tfs_z":             predict.Options.TFSZ,
+		"typical_p":         predict.Options.TypicalP,
+		"repeat_last_n":     predict.Options.RepeatLastN,
+		"repeat_penalty":    predict.Options.RepeatPenalty,
+		"presence_penalty":  predict.Options.PresencePenalty,
+		"frequency_penalty": predict.Options.FrequencyPenalty,
+		"mirostat":          predict.Options.Mirostat,
+		"mirostat_tau":      predict.Options.MirostatTau,
+		"mirostat_eta":      predict.Options.MirostatEta,
+		"penalize_nl":       predict.Options.PenalizeNewline,
+		"seed":              predict.Options.Seed,
+		"stop":              predict.Options.Stop,
 		"image_data":        imageData,
 		"cache_prompt":      true,
 	}
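The request body is assembled as a map[string]any keyed by the runner's snake_case option names, now sourced from the options carried in the request rather than the server's stored ones. A standalone sketch of the same pattern with two stand-in values (not the full option set):

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Stand-ins for fields that would come from predict.Options.
	numPredict, temperature := 64, 0.2

	request := map[string]any{
		"prompt":      "Why is the sky blue?",
		"stream":      true,
		"n_predict":   numPredict,
		"temperature": temperature,
	}

	body, err := json.Marshal(request)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
}
```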
@@ -60,7 +60,7 @@ func newDefaultExtServer(model string, adapters, projectors []string, numLayers
 }
 
 func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.Options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
@@ -166,9 +166,10 @@ const maxRetries = 3
 const retryDelay = 1 * time.Second
 
 type PredictOpts struct {
 	Prompt  string
 	Format  string
 	Images  []api.ImageData
+	Options api.Options
 }
 
 type PredictResult struct {
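Adding Options to PredictOpts is the core of the fix: the options resolved for each request now ride along with the prediction call instead of being read from whatever the loaded runner happened to be configured with. A standalone sketch of the pattern, with deliberately abbreviated mirror types rather than the real api.Options:

```go
package main

import "fmt"

// Abbreviated mirrors of the real types: options travel inside the
// request struct, so two requests against the same loaded runner can
// use different settings.
type Options struct {
	Temperature float64
}

type PredictOpts struct {
	Prompt  string
	Options Options
}

func predict(p PredictOpts) {
	fmt.Printf("prompt=%q temperature=%v\n", p.Prompt, p.Options.Temperature)
}

func main() {
	predict(PredictOpts{Prompt: "write a poem", Options: Options{Temperature: 1.0}})
	predict(PredictOpts{Prompt: "extract the date", Options: Options{Temperature: 0.0}})
}
```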
@@ -92,7 +92,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
 }
 
 func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
-	return predict(llm, llm.options, ctx, pred, fn)
+	return predict(ctx, llm, pred, fn)
 }
 
 func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
server/routes.go (130 lines changed)
@@ -64,24 +64,9 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
-	model, err := GetModel(modelName)
-	if err != nil {
-		return nil, err
-	}
-
+func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
 	workDir := c.GetString("workDir")
 
-	opts := api.DefaultOptions()
-	if err := opts.FromMap(model.Options); err != nil {
-		log.Printf("could not load model options: %v", err)
-		return nil, err
-	}
-
-	if err := opts.FromMap(reqOpts); err != nil {
-		return nil, err
-	}
-
 	needLoad := loaded.runner == nil || // is there a model loaded?
 		loaded.ModelPath != model.ModelPath || // has the base model changed?
 		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
@@ -105,7 +90,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 			err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
 		}
 
-		return nil, err
+		return err
 	}
 
 	loaded.Model = model
@@ -135,7 +120,20 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 	}
 
 	loaded.expireTimer.Reset(sessionDuration)
-	return model, nil
+	return nil
+}
+
+func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
+	opts := api.DefaultOptions()
+	if err := opts.FromMap(model.Options); err != nil {
+		return api.Options{}, err
+	}
+
+	if err := opts.FromMap(requestOpts); err != nil {
+		return api.Options{}, err
+	}
+
+	return opts, nil
 }
 
 func GenerateHandler(c *gin.Context) {
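The extracted modelOptions helper layers options in increasing precedence: library defaults first, then the Modelfile's options, then the options sent with the request. A runnable sketch of that layering; the field set and fromMap logic here are simplified stand-ins for api.Options and its FromMap:

```go
package main

import "fmt"

// Abbreviated option set; the real api.Options has many more fields.
type options struct {
	NumPredict  int
	Temperature float64
}

// fromMap overlays any recognized keys from m onto o, mimicking
// api.Options.FromMap minus its error handling and type coercion.
func (o *options) fromMap(m map[string]interface{}) {
	if v, ok := m["num_predict"].(int); ok {
		o.NumPredict = v
	}
	if v, ok := m["temperature"].(float64); ok {
		o.Temperature = v
	}
}

func main() {
	opts := options{NumPredict: 128, Temperature: 0.8}        // defaults
	opts.fromMap(map[string]interface{}{"temperature": 0.2}) // Modelfile options
	opts.fromMap(map[string]interface{}{"num_predict": 64})  // request options
	fmt.Printf("%+v\n", opts)                                 // {NumPredict:64 Temperature:0.2}
}
```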
@@ -168,18 +166,30 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
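The reworked handler classifies errors per stage: errors.As picks out a concrete *fs.PathError (model file missing) anywhere in the wrapped chain for a 404, errors.Is matches the api.ErrInvalidOpts sentinel for a 400, and anything else falls through to a 500. A standalone sketch of the errors.As half, using os.Open only as a convenient source of a *fs.PathError:

```go
package main

import (
	"errors"
	"fmt"
	"io/fs"
	"os"
)

func main() {
	// os.Open wraps failures in *fs.PathError, the same concrete type
	// the handler checks for to report "model not found".
	_, err := os.Open("/no/such/model")

	var pErr *fs.PathError
	if errors.As(err, &pErr) {
		fmt.Println("not found:", pErr.Path) // not found: /no/such/model
		return
	}
	fmt.Println("internal error:", err)
}
```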
@@ -287,9 +297,10 @@ func GenerateHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
 			Images:  req.Images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -347,18 +358,29 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	_, err = load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -991,18 +1013,29 @@ func ChatHandler(c *gin.Context) {
 		return
 	}
 
-	sessionDuration := defaultSessionDuration
-	model, err := load(c, req.Model, req.Options, sessionDuration)
+	model, err := GetModel(req.Model)
 	if err != nil {
 		var pErr *fs.PathError
-		switch {
-		case errors.As(err, &pErr):
+		if errors.As(err, &pErr) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-		case errors.Is(err, api.ErrInvalidOpts):
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		default:
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
 		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts, err := modelOptions(model, req.Options)
+	if err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	sessionDuration := defaultSessionDuration
+	if err := load(c, model, opts, sessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -1053,9 +1086,10 @@ func ChatHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
 			Prompt:  prompt,
 			Format:  req.Format,
 			Images:  images,
+			Options: opts,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}