prepend image tags (#2789)

Instead of appending image tags to the prompt, prepend them; this generally produces better results.
This commit is contained in:
Michael Yang 2024-02-29 11:30:14 -08:00 committed by GitHub
parent fa2f2b3563
commit 0e19476b56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 21 additions and 18 deletions

View file

@ -121,13 +121,15 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
p = prompt{} p = prompt{}
} }
p.Prompt = msg.Content var sb strings.Builder
for range msg.Images { for range msg.Images {
p.Prompt += fmt.Sprintf(" [img-%d]", imgId) fmt.Fprintf(&sb, "[img-%d] ", imgId)
p.images = append(p.images, imgId) p.images = append(p.images, imgId)
imgId += 1 imgId += 1
} }
sb.WriteString(msg.Content)
p.Prompt = sb.String()
case "assistant": case "assistant":
if p.Response != "" { if p.Response != "" {
prompts = append(prompts, p) prompts = append(prompts, p)

View file

@ -155,7 +155,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
}, },
window: 1024, window: 1024,
want: "You are a Wizard. Hello [img-0]", want: "You are a Wizard. [img-0] Hello",
}, },
{ {
name: "images truncated", name: "images truncated",
@ -165,7 +165,7 @@ func TestChatPrompt(t *testing.T) {
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
}, },
window: 1024, window: 1024,
want: "You are a Wizard. Hello [img-1]", want: "You are a Wizard. [img-0] [img-1] Hello",
}, },
{ {
name: "empty list", name: "empty list",
@ -198,7 +198,7 @@ func TestChatPrompt(t *testing.T) {
} }
if got != tc.want { if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want) t.Errorf("got: %q, want: %q", got, tc.want)
} }
}) })
} }

View file

@ -250,6 +250,19 @@ func GenerateHandler(c *gin.Context) {
slog.Debug("generate handler", "system", req.System) slog.Debug("generate handler", "system", req.System)
var sb strings.Builder var sb strings.Builder
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
sb.WriteString(req.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.Reset()
if req.Context != nil { if req.Context != nil {
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context) prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil { if err != nil {
@ -260,18 +273,6 @@ func GenerateHandler(c *gin.Context) {
sb.WriteString(prev) sb.WriteString(prev)
} }
// write image tags
// TODO: limit the number of images to fit in the context similar to the chat endpoint
for i := range req.Images {
req.Prompt += fmt.Sprintf(" [img-%d]", i)
}
p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(p) sb.WriteString(p)
prompt = sb.String() prompt = sb.String()