add image support to the chat api (#1490)

This commit is contained in:
Patrick Devine 2023-12-12 13:28:58 -08:00 committed by GitHub
parent 4251b342de
commit d9e60f634b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 10 deletions

View file

@ -59,6 +59,7 @@ type ChatRequest struct {
type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
Content string `json:"content"`
Images []ImageData `json:"images, omitempty"`
}
type ChatResponse struct {

View file

@ -86,9 +86,10 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
return prompt.String(), nil
}
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
// build the prompt from the list of messages
var prompt strings.Builder
var currentImages []api.ImageData
currentVars := PromptVars{
First: true,
}
@ -108,35 +109,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
case "system":
if currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
return "", nil, err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" {
if err := writePrompt(); err != nil {
return "", err
return "", nil, err
}
}
currentVars.Prompt = msg.Content
currentImages = msg.Images
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", err
return "", nil, err
}
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
return "", nil, err
}
}
return prompt.String(), nil
return prompt.String(), currentImages, nil
}
type ManifestV2 struct {

View file

@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) {
checkpointLoaded := time.Now()
prompt, err := model.ChatPrompt(req.Messages)
prompt, images, err := model.ChatPrompt(req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) {
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
Images: images,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}