add image support to the chat api (#1490)

This commit is contained in:
Patrick Devine 2023-12-12 13:28:58 -08:00 committed by GitHub
parent 4251b342de
commit d9e60f634b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 10 deletions

View file

@ -59,6 +59,7 @@ type ChatRequest struct {
type Message struct { type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"] Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"` Content string `json:"content"`
Images []ImageData `json:"images, omitempty"`
} }
type ChatResponse struct { type ChatResponse struct {

View file

@ -86,9 +86,10 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
return prompt.String(), nil return prompt.String(), nil
} }
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) { func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
// build the prompt from the list of messages // build the prompt from the list of messages
var prompt strings.Builder var prompt strings.Builder
var currentImages []api.ImageData
currentVars := PromptVars{ currentVars := PromptVars{
First: true, First: true,
} }
@ -108,35 +109,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
case "system": case "system":
if currentVars.System != "" { if currentVars.System != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
} }
currentVars.System = msg.Content currentVars.System = msg.Content
case "user": case "user":
if currentVars.Prompt != "" { if currentVars.Prompt != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
} }
currentVars.Prompt = msg.Content currentVars.Prompt = msg.Content
currentImages = msg.Images
case "assistant": case "assistant":
currentVars.Response = msg.Content currentVars.Response = msg.Content
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
default: default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
} }
} }
// Append the last set of vars if they are non-empty // Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" { if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil { if err := writePrompt(); err != nil {
return "", err return "", nil, err
} }
} }
return prompt.String(), nil return prompt.String(), currentImages, nil
} }
type ManifestV2 struct { type ManifestV2 struct {

View file

@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
prompt, err := model.ChatPrompt(req.Messages) prompt, images, err := model.ChatPrompt(req.Messages)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) {
Format: req.Format, Format: req.Format,
CheckpointStart: checkpointStart, CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded, CheckpointLoaded: checkpointLoaded,
Images: images,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}