add done_reason to the api (#4235)

This commit is contained in:
Bruce MacDonald 2024-05-09 13:30:14 -07:00 committed by GitHub
parent 1580ed4c06
commit cfa84b8470
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 44 additions and 40 deletions

View file

@ -117,6 +117,7 @@ type ChatResponse struct {
Model string `json:"model"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"` Message Message `json:"message"`
DoneReason string `json:"done_reason"`
Done bool `json:"done"` Done bool `json:"done"`
@ -309,6 +310,9 @@ type GenerateResponse struct {
// Done specifies if the response is complete. // Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`
// DoneReason is the reason the model stopped generating text.
DoneReason string `json:"done_reason"`
// Context is an encoding of the conversation used in this response; this // Context is an encoding of the conversation used in this response; this
// can be sent in the next request to keep a conversational memory. // can be sent in the next request to keep a conversational memory.
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`

View file

@ -580,6 +580,7 @@ type completion struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Stop bool `json:"stop"` Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Timings struct { Timings struct {
PredictedN int `json:"predicted_n"` PredictedN int `json:"predicted_n"`
@ -598,6 +599,7 @@ type CompletionRequest struct {
type CompletionResponse struct { type CompletionResponse struct {
Content string Content string
DoneReason string
Done bool Done bool
PromptEvalCount int PromptEvalCount int
PromptEvalDuration time.Duration PromptEvalDuration time.Duration
@ -739,8 +741,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
} }
if c.Stop { if c.Stop {
doneReason := "stop"
if c.StoppedLimit {
doneReason = "length"
}
fn(CompletionResponse{ fn(CompletionResponse{
Done: true, Done: true,
DoneReason: doneReason,
PromptEvalCount: c.Timings.PromptN, PromptEvalCount: c.Timings.PromptN,
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN, EvalCount: c.Timings.PredictedN,

View file

@ -109,13 +109,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
Choices: []Choice{{ Choices: []Choice{{
Index: 0, Index: 0,
Message: Message{Role: r.Message.Role, Content: r.Message.Content}, Message: Message{Role: r.Message.Role, Content: r.Message.Content},
FinishReason: func(done bool) *string { FinishReason: &r.DoneReason,
if done {
reason := "stop"
return &reason
}
return nil
}(r.Done),
}}, }},
Usage: Usage{ Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
@ -137,13 +131,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
{ {
Index: 0, Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content}, Delta: Message{Role: "assistant", Content: r.Message.Content},
FinishReason: func(done bool) *string { FinishReason: &r.DoneReason,
if done {
reason := "stop"
return &reason
}
return nil
}(r.Done),
}, },
}, },
} }

View file

@ -155,6 +155,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Model: req.Model, Model: req.Model,
Done: true, Done: true,
DoneReason: "load",
}) })
return return
} }
@ -226,6 +227,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Done: r.Done, Done: r.Done,
Response: r.Content, Response: r.Content,
DoneReason: r.DoneReason,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
@ -1218,6 +1220,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Model: req.Model, Model: req.Model,
Done: true, Done: true,
DoneReason: "load",
Message: api.Message{Role: "assistant"}, Message: api.Message{Role: "assistant"},
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
@ -1255,6 +1258,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content}, Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done, Done: r.Done,
DoneReason: r.DoneReason,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,