diff --git a/llm/llama.go b/llm/llama.go
index 72e67389..adaa4c57 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -548,17 +548,12 @@ const maxBufferSize = 512 * format.KiloByte
 const maxRetries = 6
 
 type PredictOpts struct {
-	Prompt           string
-	Format           string
-	Images           []api.ImageData
-	CheckpointStart  time.Time
-	CheckpointLoaded time.Time
+	Prompt string
+	Format string
+	Images []api.ImageData
 }
 
 type PredictResult struct {
-	CreatedAt          time.Time
-	TotalDuration      time.Duration
-	LoadDuration       time.Duration
 	Content            string
 	Done               bool
 	PromptEvalCount    int
@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 
 		if p.Content != "" {
 			fn(PredictResult{
-				CreatedAt: time.Now().UTC(),
-				Content:   p.Content,
+				Content: p.Content,
 			})
 		}
 
 		if p.Stop {
 			fn(PredictResult{
-				CreatedAt:     time.Now().UTC(),
-				TotalDuration: time.Since(predict.CheckpointStart),
-
 				Done:               true,
 				PromptEvalCount:    p.Timings.PromptN,
 				PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
diff --git a/server/routes.go b/server/routes.go
index 6df7d2e4..71c27b89 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context) {
 
 			resp := api.GenerateResponse{
 				Model:     req.Model,
-				CreatedAt: r.CreatedAt,
+				CreatedAt: time.Now().UTC(),
 				Done:      r.Done,
 				Response:  r.Content,
 				Metrics: api.Metrics{
-					TotalDuration:      r.TotalDuration,
-					LoadDuration:       r.LoadDuration,
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
 					EvalCount:          r.EvalCount,
@@ -274,13 +272,18 @@ func GenerateHandler(c *gin.Context) {
 				},
 			}
 
-			if r.Done && !req.Raw {
-				embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
-				if err != nil {
-					ch <- gin.H{"error": err.Error()}
-					return
+			if r.Done {
+				resp.TotalDuration = time.Since(checkpointStart)
+				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+				if !req.Raw {
+					embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+					if err != nil {
+						ch <- gin.H{"error": err.Error()}
+						return
+					}
+					resp.Context = embd
 				}
-				resp.Context = embd
 			}
 
 			ch <- resp
@@ -288,11 +291,9 @@ func GenerateHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           req.Images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: req.Images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context) {
 
 			resp := api.ChatResponse{
 				Model:     req.Model,
-				CreatedAt: r.CreatedAt,
+				CreatedAt: time.Now().UTC(),
 				Done:      r.Done,
 				Metrics: api.Metrics{
-					TotalDuration:      r.TotalDuration,
-					LoadDuration:       r.LoadDuration,
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
 					EvalCount:          r.EvalCount,
@@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context) {
 				},
 			}
 
-			if !r.Done {
+			if r.Done {
+				resp.TotalDuration = time.Since(checkpointStart)
+				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			} else {
 				resp.Message = &api.Message{Role: "assistant", Content: r.Content}
 			}
 
@@ -1033,11 +1035,9 @@ func ChatHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
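The net effect of this diff: the runner in `llm/llama.go` reports only model-side metrics (token counts and eval durations from the llama.cpp timings), while the HTTP handlers in `server/routes.go` stamp `CreatedAt` themselves and derive `TotalDuration` and `LoadDuration` from their own checkpoints, so no wall-clock state needs to flow through `llm.PredictOpts`. A minimal runnable sketch of that checkpoint arithmetic, with `loadModel` and `generate` as hypothetical stand-ins for the real handler plumbing:

```go
package main

import (
	"fmt"
	"time"
)

// Sketch of the server-side timing pattern from the diff above.
// Only the checkpoint arithmetic mirrors the change; loadModel and
// generate are hypothetical placeholders for the runner's work.
func main() {
	checkpointStart := time.Now() // request accepted

	loadModel() // load or reuse the model
	checkpointLoaded := time.Now() // model ready to serve

	generate() // stream predictions; the runner no longer tracks wall-clock time

	// On the final (Done) response, the handler derives both durations
	// from its own checkpoints instead of values threaded through PredictOpts.
	totalDuration := time.Since(checkpointStart)
	loadDuration := checkpointLoaded.Sub(checkpointStart)

	fmt.Printf("total=%v load=%v\n", totalDuration, loadDuration)
}

func loadModel() { time.Sleep(50 * time.Millisecond) }
func generate()  { time.Sleep(120 * time.Millisecond) }
```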