fix tests

This commit is contained in:
Michael Yang 2024-02-01 11:48:11 -08:00
parent d125510b4b
commit e49dc9f3d8
2 changed files with 28 additions and 8 deletions

View file

@ -238,18 +238,37 @@ func chatHistoryEqual(a, b ChatHistory) bool {
if len(a.Prompts) != len(b.Prompts) {
return false
}
if len(a.CurrentImages) != len(b.CurrentImages) {
return false
}
for i, v := range a.Prompts {
if v != b.Prompts[i] {
if v.First != b.Prompts[i].First {
return false
}
}
for i, v := range a.CurrentImages {
if !bytes.Equal(v, b.CurrentImages[i]) {
if v.Response != b.Prompts[i].Response {
return false
}
if v.Prompt != b.Prompts[i].Prompt {
return false
}
if v.System != b.Prompts[i].System {
return false
}
if len(v.Images) != len(b.Prompts[i].Images) {
return false
}
for j, img := range v.Images {
if img.ID != b.Prompts[i].Images[j].ID {
return false
}
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
return false
}
}
}
return a.LastSystem == b.LastSystem
}

View file

@ -455,7 +455,8 @@ func Test_ChatPrompt(t *testing.T) {
NumCtx: tt.numCtx,
},
}
got, err := trimmedPrompt(context.Background(), tt.chat, m)
// TODO: add tests for trimming images
got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")