add image support to the chat api (#1490)
This commit is contained in:
parent
4251b342de
commit
d9e60f634b
3 changed files with 14 additions and 10 deletions
|
@ -57,8 +57,9 @@ type ChatRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"` // one of ["system", "user", "assistant"]
|
Role string `json:"role"` // one of ["system", "user", "assistant"]
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
Images []ImageData `json:"images, omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
|
|
|
@ -86,9 +86,10 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
|
||||||
return prompt.String(), nil
|
return prompt.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
|
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
|
||||||
// build the prompt from the list of messages
|
// build the prompt from the list of messages
|
||||||
var prompt strings.Builder
|
var prompt strings.Builder
|
||||||
|
var currentImages []api.ImageData
|
||||||
currentVars := PromptVars{
|
currentVars := PromptVars{
|
||||||
First: true,
|
First: true,
|
||||||
}
|
}
|
||||||
|
@ -108,35 +109,36 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
|
||||||
case "system":
|
case "system":
|
||||||
if currentVars.System != "" {
|
if currentVars.System != "" {
|
||||||
if err := writePrompt(); err != nil {
|
if err := writePrompt(); err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
currentVars.System = msg.Content
|
currentVars.System = msg.Content
|
||||||
case "user":
|
case "user":
|
||||||
if currentVars.Prompt != "" {
|
if currentVars.Prompt != "" {
|
||||||
if err := writePrompt(); err != nil {
|
if err := writePrompt(); err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
currentVars.Prompt = msg.Content
|
currentVars.Prompt = msg.Content
|
||||||
|
currentImages = msg.Images
|
||||||
case "assistant":
|
case "assistant":
|
||||||
currentVars.Response = msg.Content
|
currentVars.Response = msg.Content
|
||||||
if err := writePrompt(); err != nil {
|
if err := writePrompt(); err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append the last set of vars if they are non-empty
|
// Append the last set of vars if they are non-empty
|
||||||
if currentVars.Prompt != "" || currentVars.System != "" {
|
if currentVars.Prompt != "" || currentVars.System != "" {
|
||||||
if err := writePrompt(); err != nil {
|
if err := writePrompt(); err != nil {
|
||||||
return "", err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return prompt.String(), nil
|
return prompt.String(), currentImages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ManifestV2 struct {
|
type ManifestV2 struct {
|
||||||
|
|
|
@ -994,7 +994,7 @@ func ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
prompt, err := model.ChatPrompt(req.Messages)
|
prompt, images, err := model.ChatPrompt(req.Messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -1037,6 +1037,7 @@ func ChatHandler(c *gin.Context) {
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
CheckpointStart: checkpointStart,
|
CheckpointStart: checkpointStart,
|
||||||
CheckpointLoaded: checkpointLoaded,
|
CheckpointLoaded: checkpointLoaded,
|
||||||
|
Images: images,
|
||||||
}
|
}
|
||||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
|
Loading…
Reference in a new issue