add keep_alive to generate/chat/embedding api endpoints (#2146)
This commit is contained in:
parent
cc4915e262
commit
b5cf31b460
2 changed files with 48 additions and 20 deletions
12
api/types.go
12
api/types.go
|
@@ -42,6 +42,7 @@ type GenerateRequest struct {
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
Raw bool `json:"raw,omitempty"`
|
Raw bool `json:"raw,omitempty"`
|
||||||
Format string `json:"format"`
|
Format string `json:"format"`
|
||||||
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
|
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
|
@@ -52,6 +53,7 @@ type ChatRequest struct {
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
Format string `json:"format"`
|
Format string `json:"format"`
|
||||||
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
@@ -128,6 +130,7 @@ type Runner struct {
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
@@ -413,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||||
case float64:
|
case float64:
|
||||||
if t < 0 {
|
if t < 0 {
|
||||||
t = math.MaxFloat64
|
t = math.MaxFloat64
|
||||||
}
|
|
||||||
|
|
||||||
d.Duration = time.Duration(t)
|
d.Duration = time.Duration(t)
|
||||||
|
} else {
|
||||||
|
d.Duration = time.Duration(t * float64(time.Second))
|
||||||
|
}
|
||||||
case string:
|
case string:
|
||||||
d.Duration, err = time.ParseDuration(t)
|
d.Duration, err = time.ParseDuration(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if d.Duration < 0 {
|
||||||
|
mf := math.MaxFloat64
|
||||||
|
d.Duration = time.Duration(mf)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionDuration := defaultSessionDuration
|
var sessionDuration time.Duration
|
||||||
|
if req.KeepAlive == nil {
|
||||||
|
sessionDuration = defaultSessionDuration
|
||||||
|
} else {
|
||||||
|
sessionDuration = req.KeepAlive.Duration
|
||||||
|
}
|
||||||
|
|
||||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sessionDuration := defaultSessionDuration
|
|
||||||
|
var sessionDuration time.Duration
|
||||||
|
if req.KeepAlive == nil {
|
||||||
|
sessionDuration = defaultSessionDuration
|
||||||
|
} else {
|
||||||
|
sessionDuration = req.KeepAlive.Duration
|
||||||
|
}
|
||||||
|
|
||||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@@ -1074,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sessionDuration := defaultSessionDuration
|
|
||||||
|
var sessionDuration time.Duration
|
||||||
|
if req.KeepAlive == nil {
|
||||||
|
sessionDuration = defaultSessionDuration
|
||||||
|
} else {
|
||||||
|
sessionDuration = req.KeepAlive.Duration
|
||||||
|
}
|
||||||
|
|
||||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
if err := load(c, model, opts, sessionDuration); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in a new issue