restore model load duration on generate response (#1524)

* restore model load duration on generate response

- set model load duration on generate and chat done responses
- calculate createdAt time when the response is created

* remove checkpoints from predict opts

* Update routes.go
Bruce MacDonald 2023-12-14 12:15:50 -05:00 committed by GitHub
parent 31f0551dab
commit 6ee8c80199
2 changed files with 27 additions and 36 deletions
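
A minimal sketch of the timing pattern this commit moves into the server handlers, assuming checkpointStart is captured when the request starts being handled and checkpointLoaded once the model is resident (both names appear in the routes.go diff below); the sleeps are stand-ins for the real load and prediction work:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        checkpointStart := time.Now() // request handling begins

        time.Sleep(50 * time.Millisecond) // stand-in for loading the model
        checkpointLoaded := time.Now()    // model is now resident

        time.Sleep(100 * time.Millisecond) // stand-in for prediction/streaming

        // On the final (Done) response the handler now computes both
        // durations itself instead of receiving them from llm.Predict:
        loadDuration := checkpointLoaded.Sub(checkpointStart) // time spent loading
        totalDuration := time.Since(checkpointStart)          // end-to-end time

        fmt.Println("load:", loadDuration, "total:", totalDuration)
    }

Computing the durations in the handlers is what lets PredictOpts drop its checkpoint fields and PredictResult drop CreatedAt, TotalDuration, and LoadDuration, as the diffs below show.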

llm/llama.go

@@ -548,17 +548,12 @@ const maxBufferSize = 512 * format.KiloByte
 const maxRetries = 6
 
 type PredictOpts struct {
-	Prompt           string
-	Format           string
-	Images           []api.ImageData
-	CheckpointStart  time.Time
-	CheckpointLoaded time.Time
+	Prompt string
+	Format string
+	Images []api.ImageData
 }
 
 type PredictResult struct {
-	CreatedAt          time.Time
-	TotalDuration      time.Duration
-	LoadDuration       time.Duration
 	Content            string
 	Done               bool
 	PromptEvalCount    int
@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 		if p.Content != "" {
 			fn(PredictResult{
-				CreatedAt: time.Now().UTC(),
-				Content:   p.Content,
+				Content: p.Content,
 			})
 		}
 
 		if p.Stop {
 			fn(PredictResult{
-				CreatedAt:          time.Now().UTC(),
-				TotalDuration:      time.Since(predict.CheckpointStart),
 				Done:               true,
 				PromptEvalCount:    p.Timings.PromptN,
 				PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),

server/routes.go

@@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context) {
 			resp := api.GenerateResponse{
 				Model:     req.Model,
-				CreatedAt: r.CreatedAt,
+				CreatedAt: time.Now().UTC(),
 				Done:      r.Done,
 				Response:  r.Content,
 				Metrics: api.Metrics{
-					TotalDuration:      r.TotalDuration,
-					LoadDuration:       r.LoadDuration,
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
 					EvalCount:          r.EvalCount,
@@ -274,13 +272,18 @@ func GenerateHandler(c *gin.Context) {
 				},
 			}
 
-			if r.Done && !req.Raw {
-				embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
-				if err != nil {
-					ch <- gin.H{"error": err.Error()}
-					return
+			if r.Done {
+				resp.TotalDuration = time.Since(checkpointStart)
+				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+				if !req.Raw {
+					embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+					if err != nil {
+						ch <- gin.H{"error": err.Error()}
+						return
+					}
+					resp.Context = embd
 				}
-				resp.Context = embd
 			}
 
 			ch <- resp
@@ -288,11 +291,9 @@ func GenerateHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           req.Images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: req.Images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context) {
 			resp := api.ChatResponse{
 				Model:     req.Model,
-				CreatedAt: r.CreatedAt,
+				CreatedAt: time.Now().UTC(),
 				Done:      r.Done,
 				Metrics: api.Metrics{
-					TotalDuration:      r.TotalDuration,
-					LoadDuration:       r.LoadDuration,
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
 					EvalCount:          r.EvalCount,
@@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context) {
 				},
 			}
 
-			if !r.Done {
+			if r.Done {
+				resp.TotalDuration = time.Since(checkpointStart)
+				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			} else {
 				resp.Message = &api.Message{Role: "assistant", Content: r.Content}
 			}
@@ -1033,11 +1035,9 @@ func ChatHandler(c *gin.Context) {
 
 		// Start prediction
 		predictReq := llm.PredictOpts{
-			Prompt:           prompt,
-			Format:           req.Format,
-			CheckpointStart:  checkpointStart,
-			CheckpointLoaded: checkpointLoaded,
-			Images:           images,
+			Prompt: prompt,
+			Format: req.Format,
+			Images: images,
 		}
 		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
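
With this change the load time should again appear on the final streamed object from /api/generate and /api/chat. An illustrative final generate response (field values invented for the example; durations are nanoseconds):

    {"model":"llama2","created_at":"2023-12-14T17:15:50Z","response":"","done":true,"total_duration":5589157167,"load_duration":3013701500,"prompt_eval_count":26,"eval_count":282}

Here load_duration is checkpointLoaded minus checkpointStart, and total_duration is the time elapsed since checkpointStart.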