From b5cf31b4606a1faa083bd713ea9233bcf46ee570 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 26 Jan 2024 14:28:02 -0800 Subject: [PATCH] add keep_alive to generate/chat/embedding api endpoints (#2146) --- api/types.go | 42 +++++++++++++++++++++++++----------------- server/routes.go | 26 +++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/api/types.go b/api/types.go index 585daf6c..609c4a8a 100644 --- a/api/types.go +++ b/api/types.go @@ -34,24 +34,26 @@ func (e StatusError) Error() string { type ImageData []byte type GenerateRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - System string `json:"system"` - Template string `json:"template"` - Context []int `json:"context,omitempty"` - Stream *bool `json:"stream,omitempty"` - Raw bool `json:"raw,omitempty"` - Format string `json:"format"` - Images []ImageData `json:"images,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt"` + System string `json:"system"` + Template string `json:"template"` + Context []int `json:"context,omitempty"` + Stream *bool `json:"stream,omitempty"` + Raw bool `json:"raw,omitempty"` + Format string `json:"format"` + KeepAlive *Duration `json:"keep_alive,omitempty"` + Images []ImageData `json:"images,omitempty"` Options map[string]interface{} `json:"options"` } type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Stream *bool `json:"stream,omitempty"` - Format string `json:"format"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream *bool `json:"stream,omitempty"` + Format string `json:"format"` + KeepAlive *Duration `json:"keep_alive,omitempty"` Options map[string]interface{} `json:"options"` } @@ -126,8 +128,9 @@ type Runner struct { } type EmbeddingRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` + Model string `json:"model"` + Prompt string `json:"prompt"` + KeepAlive *Duration `json:"keep_alive,omitempty"` Options map[string]interface{} `json:"options"` } @@ -413,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { case float64: if t < 0 { t = math.MaxFloat64 + d.Duration = time.Duration(t) + } else { + d.Duration = time.Duration(t * float64(time.Second)) } - - d.Duration = time.Duration(t) case string: d.Duration, err = time.ParseDuration(t) if err != nil { return err } + if d.Duration < 0 { + mf := math.MaxFloat64 + d.Duration = time.Duration(mf) + } } return nil diff --git a/server/routes.go b/server/routes.go index 141f05d4..56c275c9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) { return } - sessionDuration := defaultSessionDuration + var sessionDuration time.Duration + if req.KeepAlive == nil { + sessionDuration = defaultSessionDuration + } else { + sessionDuration = req.KeepAlive.Duration + } + if err := load(c, model, opts, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - sessionDuration := defaultSessionDuration + + var sessionDuration time.Duration + if req.KeepAlive == nil { + sessionDuration = defaultSessionDuration + } else { + sessionDuration = req.KeepAlive.Duration + } + if err := load(c, model, opts, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -1074,7 +1087,14 @@ func ChatHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - sessionDuration := defaultSessionDuration + + var sessionDuration time.Duration + if req.KeepAlive == nil { + sessionDuration = defaultSessionDuration + } else { + sessionDuration = req.KeepAlive.Duration + } + if err := load(c, model, opts, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return