send empty messages on last chat response (#1530)

This commit is contained in:
Bruce MacDonald 2023-12-18 14:23:38 -05:00 committed by GitHub
parent 3948c6ea06
commit d99fa6ce0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 10 deletions

View file

@ -59,13 +59,13 @@ type ChatRequest struct {
type Message struct { type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"] Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"` Content string `json:"content"`
Images []ImageData `json:"images, omitempty"` Images []ImageData `json:"images,omitempty"`
} }
type ChatResponse struct { type ChatResponse struct {
Model string `json:"model"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Message *Message `json:"message,omitempty"` Message Message `json:"message"`
Done bool `json:"done"` Done bool `json:"done"`

View file

@ -1013,7 +1013,7 @@ func ChatHandler(c *gin.Context) {
// an empty request loads the model // an empty request loads the model
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
return return
} }
@ -1038,6 +1038,7 @@ func ChatHandler(c *gin.Context) {
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done, Done: r.Done,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
@ -1050,8 +1051,6 @@ func ChatHandler(c *gin.Context) {
if r.Done { if r.Done {
resp.TotalDuration = time.Since(checkpointStart) resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} else {
resp.Message = &api.Message{Role: "assistant", Content: r.Content}
} }
ch <- resp ch <- resp
@ -1075,10 +1074,7 @@ func ChatHandler(c *gin.Context) {
for resp := range ch { for resp := range ch {
switch r := resp.(type) { switch r := resp.(type) {
case api.ChatResponse: case api.ChatResponse:
if r.Message != nil {
sb.WriteString(r.Message.Content) sb.WriteString(r.Message.Content)
}
final = r final = r
case gin.H: case gin.H:
if errorMsg, ok := r["error"].(string); ok { if errorMsg, ok := r["error"].(string); ok {
@ -1094,7 +1090,7 @@ func ChatHandler(c *gin.Context) {
} }
} }
final.Message = &api.Message{Role: "assistant", Content: sb.String()} final.Message = api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final) c.JSON(http.StatusOK, final)
return return
} }