add done_reason to the api (#4235)

Author: Bruce MacDonald
Date:   2024-05-09 13:30:14 -07:00 (committed by GitHub)
Commit: cfa84b8470 (parent 1580ed4c06)
4 changed files with 44 additions and 40 deletions

api/types.go

@@ -114,9 +114,10 @@ type Message struct {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model     string    `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	Message   Message   `json:"message"`
+	Model      string    `json:"model"`
+	CreatedAt  time.Time `json:"created_at"`
+	Message    Message   `json:"message"`
+	DoneReason string    `json:"done_reason"`
 
 	Done bool `json:"done"`
@@ -309,6 +310,9 @@ type GenerateResponse struct {
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
 
+	// DoneReason is the reason the model stopped generating text.
+	DoneReason string `json:"done_reason"`
+
 	// Context is an encoding of the conversation used in this response; this
 	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`
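
With both response types carrying the new field, a streaming caller can check why generation ended once the final chunk arrives. A minimal sketch against the Go client, assuming a local server on the default port (the model name is illustrative; ClientFromEnvironment and Chat are the client's existing entry points):

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model:    "llama3", // illustrative; any pulled model works
		Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
	}

	// the callback fires once per streamed chunk; DoneReason is only
	// meaningful on the final chunk, where Done is true
	err = client.Chat(context.Background(), req, func(r api.ChatResponse) error {
		if r.Done {
			fmt.Println("done_reason:", r.DoneReason) // "stop" or "length"
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}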

llm/server.go

@@ -576,10 +576,11 @@ type ImageData struct {
 }
 
 type completion struct {
-	Content string `json:"content"`
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Stop    bool   `json:"stop"`
+	Content      string `json:"content"`
+	Model        string `json:"model"`
+	Prompt       string `json:"prompt"`
+	Stop         bool   `json:"stop"`
+	StoppedLimit bool   `json:"stopped_limit"`
 
 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -598,6 +599,7 @@ type CompletionRequest struct {
 
 type CompletionResponse struct {
 	Content            string
+	DoneReason         string
 	Done               bool
 	PromptEvalCount    int
 	PromptEvalDuration time.Duration
@@ -739,8 +741,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			}
 
 			if c.Stop {
+				doneReason := "stop"
+				if c.StoppedLimit {
+					doneReason = "length"
+				}
+
 				fn(CompletionResponse{
-					Done: true,
+					Done:               true,
+					DoneReason:         doneReason,
 					PromptEvalCount:    c.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					EvalCount:          c.Timings.PredictedN,
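
The rule encoded above: a finished completion defaults to "stop", and becomes "length" when the llama.cpp server sets stopped_limit, i.e. the token limit cut generation short. Restated as a standalone helper for clarity (a sketch, not code from this commit):

// doneReason restates the mapping above: stopped_limit from the
// llama.cpp server means the predict limit ended generation.
func doneReason(stop, stoppedLimit bool) string {
	if !stop {
		return "" // generation still in progress
	}
	if stoppedLimit {
		return "length" // cut off by the token limit
	}
	return "stop" // natural end: EOS token or stop sequence
}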

openai/openai.go

@@ -107,15 +107,9 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 		Model:             r.Model,
 		SystemFingerprint: "fp_ollama",
 		Choices: []Choice{{
-			Index:   0,
-			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
-			FinishReason: func(done bool) *string {
-				if done {
-					reason := "stop"
-					return &reason
-				}
-				return nil
-			}(r.Done),
+			Index:        0,
+			Message:      Message{Role: r.Message.Role, Content: r.Message.Content},
+			FinishReason: &r.DoneReason,
 		}},
 		Usage: Usage{
 			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
@@ -135,15 +129,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 		SystemFingerprint: "fp_ollama",
 		Choices: []ChunkChoice{
 			{
-				Index: 0,
-				Delta: Message{Role: "assistant", Content: r.Message.Content},
-				FinishReason: func(done bool) *string {
-					if done {
-						reason := "stop"
-						return &reason
-					}
-					return nil
-				}(r.Done),
+				Index:        0,
+				Delta:        Message{Role: "assistant", Content: r.Message.Content},
+				FinishReason: &r.DoneReason,
 			},
 		},
 	}
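
Both OpenAI-compatibility converters now forward DoneReason directly instead of synthesizing "stop" whenever Done is set, so a "length" finish can reach OpenAI-style clients as well. A hedged sketch of reading the field over the compatibility endpoint (default port and the /v1/chat/completions path assumed; the model name is illustrative):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	body := []byte(`{"model":"llama3","messages":[{"role":"user","content":"hi"}],"stream":false}`)
	resp, err := http.Post("http://localhost:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// decode only the field this change affects
	var out struct {
		Choices []struct {
			FinishReason *string `json:"finish_reason"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	if len(out.Choices) > 0 && out.Choices[0].FinishReason != nil {
		fmt.Println("finish_reason:", *out.Choices[0].FinishReason) // "stop" or "length"
	}
}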

server/routes.go

@@ -152,9 +152,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
 		})
 		return
 	}
@@ -222,10 +223,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		resp := api.GenerateResponse{
-			Model:     req.Model,
-			CreatedAt: time.Now().UTC(),
-			Done:      r.Done,
-			Response:  r.Content,
+			Model:      req.Model,
+			CreatedAt:  time.Now().UTC(),
+			Done:       r.Done,
+			Response:   r.Content,
+			DoneReason: r.DoneReason,
 			Metrics: api.Metrics{
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
@@ -1215,10 +1217,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	// an empty request loads the model
 	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
-			CreatedAt: time.Now().UTC(),
-			Model:     req.Model,
-			Done:      true,
-			Message:   api.Message{Role: "assistant"},
+			CreatedAt:  time.Now().UTC(),
+			Model:      req.Model,
+			Done:       true,
+			DoneReason: "load",
+			Message:    api.Message{Role: "assistant"},
 		}
 		c.JSON(http.StatusOK, resp)
 		return
@@ -1251,10 +1254,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		fn := func(r llm.CompletionResponse) {
 			resp := api.ChatResponse{
-				Model:     req.Model,
-				CreatedAt: time.Now().UTC(),
-				Message:   api.Message{Role: "assistant", Content: r.Content},
-				Done:      r.Done,
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Message:    api.Message{Role: "assistant", Content: r.Content},
+				Done:       r.Done,
+				DoneReason: r.DoneReason,
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
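
Taken together, the handler changes give native-API callers three distinguishable outcomes: "load" when an empty request only loads the model, and "stop" or "length" for real generations. A minimal sketch of the load path, which per the handler above answers with a single JSON object (default port assumed; model name illustrative):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	// no prompt: GenerateHandler responds immediately after loading the model
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json",
		bytes.NewReader([]byte(`{"model":"llama3"}`)))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	var out struct {
		Done       bool   `json:"done"`
		DoneReason string `json:"done_reason"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}
	fmt.Println(out.Done, out.DoneReason) // expected: true load
}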