restore model load duration on generate response (#1524)
* restore model load duration on generate response
  - set model load duration on the generate and chat done responses
  - calculate createdAt time when the response is created
* remove checkpoints from predict opts
* Update routes.go
parent 31f0551dab
commit 6ee8c80199

2 changed files with 27 additions and 36 deletions
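The net effect: wall-clock bookkeeping moves out of the llama runner and into the route handlers. Each handler already records a checkpoint before and after loading the model, and now computes the durations on the final (done) response itself instead of threading checkpoints through PredictOpts. A minimal sketch of that pattern, with stand-in loadModel/runPrediction steps; only the checkpoint arithmetic is taken from this diff:

```go
package main

import (
	"fmt"
	"time"
)

// doneResponse keeps just the two duration fields this commit restores.
type doneResponse struct {
	TotalDuration time.Duration
	LoadDuration  time.Duration
}

func main() {
	checkpointStart := time.Now() // taken before the model is (re)loaded

	loadModel() // stand-in for the handler's model load

	checkpointLoaded := time.Now() // taken once the model is ready

	runPrediction() // stand-in for the streamed prediction loop

	// On the final (done) response the handler computes both durations
	// itself, rather than reading them off the runner's PredictResult.
	resp := doneResponse{
		TotalDuration: time.Since(checkpointStart),
		LoadDuration:  checkpointLoaded.Sub(checkpointStart),
	}
	fmt.Println("total:", resp.TotalDuration, "load:", resp.LoadDuration)
}

func loadModel()     { time.Sleep(10 * time.Millisecond) }
func runPrediction() { time.Sleep(20 * time.Millisecond) }
```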
llm/llama.go (17 changed lines)

@@ -548,17 +548,12 @@ const maxBufferSize = 512 * format.KiloByte
 const maxRetries = 6
 
 type PredictOpts struct {
-	Prompt           string
-	Format           string
-	Images           []api.ImageData
-	CheckpointStart  time.Time
-	CheckpointLoaded time.Time
+	Prompt string
+	Format string
+	Images []api.ImageData
 }
 
 type PredictResult struct {
-	CreatedAt          time.Time
-	TotalDuration      time.Duration
-	LoadDuration       time.Duration
 	Content            string
 	Done               bool
 	PromptEvalCount    int
@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 
 		if p.Content != "" {
 			fn(PredictResult{
-				CreatedAt: time.Now().UTC(),
-				Content:   p.Content,
+				Content: p.Content,
 			})
 		}
 
 		if p.Stop {
 			fn(PredictResult{
-				CreatedAt:     time.Now().UTC(),
-				TotalDuration: time.Since(predict.CheckpointStart),
-
 				Done:               true,
 				PromptEvalCount:    p.Timings.PromptN,
 				PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
server/routes.go (46 changed lines)

@@ -261,12 +261,10 @@ func GenerateHandler(c *gin.Context) {
 
 		resp := api.GenerateResponse{
 			Model:     req.Model,
-			CreatedAt: r.CreatedAt,
+			CreatedAt: time.Now().UTC(),
 			Done:      r.Done,
 			Response:  r.Content,
 			Metrics: api.Metrics{
-				TotalDuration:      r.TotalDuration,
-				LoadDuration:       r.LoadDuration,
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
 				EvalCount:          r.EvalCount,
@@ -274,13 +272,18 @@ func GenerateHandler(c *gin.Context) {
 			},
 		}
 
-		if r.Done && !req.Raw {
-			embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
-			if err != nil {
-				ch <- gin.H{"error": err.Error()}
-				return
-			}
-			resp.Context = embd
+		if r.Done {
+			resp.TotalDuration = time.Since(checkpointStart)
+			resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+			if !req.Raw {
+				embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+				if err != nil {
+					ch <- gin.H{"error": err.Error()}
+					return
+				}
+				resp.Context = embd
+			}
 		}
 
 		ch <- resp
@@ -288,11 +291,9 @@ func GenerateHandler(c *gin.Context) {
 
 	// Start prediction
 	predictReq := llm.PredictOpts{
-		Prompt:           prompt,
-		Format:           req.Format,
-		CheckpointStart:  checkpointStart,
-		CheckpointLoaded: checkpointLoaded,
-		Images:           req.Images,
+		Prompt: prompt,
+		Format: req.Format,
+		Images: req.Images,
 	}
 	if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 		ch <- gin.H{"error": err.Error()}
@@ -1012,11 +1013,9 @@ func ChatHandler(c *gin.Context) {
 
 		resp := api.ChatResponse{
 			Model:     req.Model,
-			CreatedAt: r.CreatedAt,
+			CreatedAt: time.Now().UTC(),
 			Done:      r.Done,
 			Metrics: api.Metrics{
-				TotalDuration:      r.TotalDuration,
-				LoadDuration:       r.LoadDuration,
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
 				EvalCount:          r.EvalCount,
@@ -1024,7 +1023,10 @@ func ChatHandler(c *gin.Context) {
 			},
 		}
 
-		if !r.Done {
+		if r.Done {
+			resp.TotalDuration = time.Since(checkpointStart)
+			resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+		} else {
 			resp.Message = &api.Message{Role: "assistant", Content: r.Content}
 		}
 
@@ -1033,11 +1035,9 @@ func ChatHandler(c *gin.Context) {
 
 	// Start prediction
 	predictReq := llm.PredictOpts{
-		Prompt:           prompt,
-		Format:           req.Format,
-		CheckpointStart:  checkpointStart,
-		CheckpointLoaded: checkpointLoaded,
-		Images:           images,
+		Prompt: prompt,
+		Format: req.Format,
+		Images: images,
 	}
 	if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 		ch <- gin.H{"error": err.Error()}
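With the durations restored, the final streamed chunk of a generate or chat call should carry them again. A rough client-side check against a local server on the default port; the struct below keeps only the restored fields and assumes the api.Metrics JSON tags (total_duration, load_duration, reported in nanoseconds), with the model name and prompt purely illustrative:

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// chunk decodes just the fields this commit restores; the durations
// arrive as nanosecond counts, which time.Duration unmarshals directly.
type chunk struct {
	Done          bool          `json:"done"`
	TotalDuration time.Duration `json:"total_duration"`
	LoadDuration  time.Duration `json:"load_duration"`
}

func main() {
	body := []byte(`{"model": "llama2", "prompt": "why is the sky blue?"}`)
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The stream is one JSON object per line; only the last object
	// (done == true) carries the duration fields.
	sc := bufio.NewScanner(resp.Body)
	for sc.Scan() {
		var c chunk
		if err := json.Unmarshal(sc.Bytes(), &c); err != nil {
			panic(err)
		}
		if c.Done {
			fmt.Printf("total: %v, load: %v\n", c.TotalDuration, c.LoadDuration)
		}
	}
}
```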