add keep_alive to generate/chat/embedding api endpoints (#2146)

This commit is contained in:
Patrick Devine 2024-01-26 14:28:02 -08:00 committed by GitHub
parent cc4915e262
commit b5cf31b460
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 48 additions and 20 deletions

View file

@ -42,6 +42,7 @@ type GenerateRequest struct {
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"` Raw bool `json:"raw,omitempty"`
Format string `json:"format"` Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
@ -52,6 +53,7 @@ type ChatRequest struct {
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Format string `json:"format"` Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@ -128,6 +130,7 @@ type Runner struct {
type EmbeddingRequest struct { type EmbeddingRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@ -413,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
case float64: case float64:
if t < 0 { if t < 0 {
t = math.MaxFloat64 t = math.MaxFloat64
}
d.Duration = time.Duration(t) d.Duration = time.Duration(t)
} else {
d.Duration = time.Duration(t * float64(time.Second))
}
case string: case string:
d.Duration, err = time.ParseDuration(t) d.Duration, err = time.ParseDuration(t)
if err != nil { if err != nil {
return err return err
} }
if d.Duration < 0 {
mf := math.MaxFloat64
d.Duration = time.Duration(mf)
}
} }
return nil return nil

View file

@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
return return
} }
sessionDuration := defaultSessionDuration var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -1074,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return