add image support to the chat api (#1490)
This commit is contained in:
parent
4251b342de
commit
d9e60f634b
3 changed files with 14 additions and 10 deletions
|
@ -57,8 +57,9 @@ type ChatRequest struct {
|
|||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"` // one of ["system", "user", "assistant"]
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"` // one of ["system", "user", "assistant"]
|
||||
Content string `json:"content"`
|
||||
Images []ImageData `json:"images, omitempty"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
|
|
|
@ -86,9 +86,10 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
|
|||
return prompt.String(), nil
|
||||
}
|
||||
|
||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
|
||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
||||
// build the prompt from the list of messages
|
||||
var prompt strings.Builder
|
||||
var currentImages []api.ImageData
|
||||
currentVars := PromptVars{
|
||||
First: true,
|
||||
}
|
||||
|
@ -108,35 +109,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
|
|||
case "system":
|
||||
if currentVars.System != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
currentVars.System = msg.Content
|
||||
case "user":
|
||||
if currentVars.Prompt != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
currentVars.Prompt = msg.Content
|
||||
currentImages = msg.Images
|
||||
case "assistant":
|
||||
currentVars.Response = msg.Content
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||
}
|
||||
}
|
||||
|
||||
// Append the last set of vars if they are non-empty
|
||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||
if err := writePrompt(); err != nil {
|
||||
return "", err
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return prompt.String(), nil
|
||||
return prompt.String(), currentImages, nil
|
||||
}
|
||||
|
||||
type ManifestV2 struct {
|
||||
|
|
|
@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) {
|
|||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
prompt, err := model.ChatPrompt(req.Messages)
|
||||
prompt, images, err := model.ChatPrompt(req.Messages)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) {
|
|||
Format: req.Format,
|
||||
CheckpointStart: checkpointStart,
|
||||
CheckpointLoaded: checkpointLoaded,
|
||||
Images: images,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
|
|
Loading…
Reference in a new issue