prepend image tags (#2789)
Instead of appending image tags to the prompt, prepend them; this generally produces better results.
This commit is contained in:
parent
fa2f2b3563
commit
0e19476b56
3 changed files with 21 additions and 18 deletions
|
@@ -121,13 +121,15 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
|
||||||
p = prompt{}
|
p = prompt{}
|
||||||
}
|
}
|
||||||
|
|
||||||
p.Prompt = msg.Content
|
var sb strings.Builder
|
||||||
|
|
||||||
for range msg.Images {
|
for range msg.Images {
|
||||||
p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
|
fmt.Fprintf(&sb, "[img-%d] ", imgId)
|
||||||
p.images = append(p.images, imgId)
|
p.images = append(p.images, imgId)
|
||||||
imgId += 1
|
imgId += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sb.WriteString(msg.Content)
|
||||||
|
p.Prompt = sb.String()
|
||||||
case "assistant":
|
case "assistant":
|
||||||
if p.Response != "" {
|
if p.Response != "" {
|
||||||
prompts = append(prompts, p)
|
prompts = append(prompts, p)
|
||||||
|
|
|
@@ -155,7 +155,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
|
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
|
||||||
},
|
},
|
||||||
window: 1024,
|
window: 1024,
|
||||||
want: "You are a Wizard. Hello [img-0]",
|
want: "You are a Wizard. [img-0] Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "images truncated",
|
name: "images truncated",
|
||||||
|
@@ -165,7 +165,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
|
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
|
||||||
},
|
},
|
||||||
window: 1024,
|
window: 1024,
|
||||||
want: "You are a Wizard. Hello [img-1]",
|
want: "You are a Wizard. [img-0] [img-1] Hello",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty list",
|
name: "empty list",
|
||||||
|
@@ -198,7 +198,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if got != tc.want {
|
if got != tc.want {
|
||||||
t.Errorf("got = %v, want %v", got, tc.want)
|
t.Errorf("got: %q, want: %q", got, tc.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@@ -250,6 +250,19 @@ func GenerateHandler(c *gin.Context) {
|
||||||
slog.Debug("generate handler", "system", req.System)
|
slog.Debug("generate handler", "system", req.System)
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
for i := range req.Images {
|
||||||
|
fmt.Fprintf(&sb, "[img-%d] ", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(req.Prompt)
|
||||||
|
|
||||||
|
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.Reset()
|
||||||
if req.Context != nil {
|
if req.Context != nil {
|
||||||
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -260,18 +273,6 @@ func GenerateHandler(c *gin.Context) {
|
||||||
sb.WriteString(prev)
|
sb.WriteString(prev)
|
||||||
}
|
}
|
||||||
|
|
||||||
// write image tags
|
|
||||||
// TODO: limit the number of images to fit in the context similar to the chat endpoint
|
|
||||||
for i := range req.Images {
|
|
||||||
req.Prompt += fmt.Sprintf(" [img-%d]", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString(p)
|
sb.WriteString(p)
|
||||||
|
|
||||||
prompt = sb.String()
|
prompt = sb.String()
|
||||||
|
|
Loading…
Reference in a new issue