server: fix context
, load_duration
and total_duration
fields (#5676)
* server: fix `contet`, `load_duration` and `total_duration` fields * Update server/routes.go
This commit is contained in:
parent
ef98803d63
commit
1ed0aa8fea
1 changed files with 46 additions and 10 deletions
|
@ -102,6 +102,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
|
checkpointStart := time.Now()
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
@ -129,6 +130,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
if req.Prompt == "" {
|
if req.Prompt == "" {
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
|
@ -191,26 +194,48 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
|
// TODO (jmorganca): avoid building the response twice both here and below
|
||||||
|
var sb strings.Builder
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Images: images,
|
Images: images,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(cr llm.CompletionResponse) {
|
||||||
ch <- api.GenerateResponse{
|
res := api.GenerateResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Response: r.Content,
|
Response: cr.Content,
|
||||||
Done: r.Done,
|
Done: cr.Done,
|
||||||
DoneReason: r.DoneReason,
|
DoneReason: cr.DoneReason,
|
||||||
Metrics: api.Metrics{
|
Metrics: api.Metrics{
|
||||||
PromptEvalCount: r.PromptEvalCount,
|
PromptEvalCount: cr.PromptEvalCount,
|
||||||
PromptEvalDuration: r.PromptEvalDuration,
|
PromptEvalDuration: cr.PromptEvalDuration,
|
||||||
EvalCount: r.EvalCount,
|
EvalCount: cr.EvalCount,
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: cr.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||||
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cr.Done {
|
||||||
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
|
if !req.Raw {
|
||||||
|
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||||
|
if err != nil {
|
||||||
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Context = append(req.Context, tokens...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
|
@ -1122,6 +1147,8 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ChatHandler(c *gin.Context) {
|
func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
var req api.ChatRequest
|
var req api.ChatRequest
|
||||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
|
@ -1141,6 +1168,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
if len(req.Messages) == 0 {
|
if len(req.Messages) == 0 {
|
||||||
c.JSON(http.StatusOK, api.ChatResponse{
|
c.JSON(http.StatusOK, api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
|
@ -1169,7 +1198,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}, func(r llm.CompletionResponse) {
|
}, func(r llm.CompletionResponse) {
|
||||||
ch <- api.ChatResponse{
|
res := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||||
|
@ -1182,6 +1211,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
EvalDuration: r.EvalDuration,
|
EvalDuration: r.EvalDuration,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.Done {
|
||||||
|
res.TotalDuration = time.Since(checkpointStart)
|
||||||
|
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- res
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue