ollama/llama/runner/image_test.go
Jesse Gross c826e57475 runner.go: Better abstract vision model integration
-Update mllama to take the cross attention state as embeddings in
a batch, more similar to how Llava handles it. This improves
integration with the input cache.
-Pass locations in a prompt for embeddings using tags similar to Llava.
-Abstract interface to vision models so the main runner accesses Clip
and Mllama similarly

Co-authored-by: Michael Yang <mxyng@pm.me>
2024-10-30 14:53:43 -07:00

80 lines
2.4 KiB
Go

package main
import (
	"errors"
	"reflect"
	"testing"
)
// TestImageCache exercises ImageContext's fixed-size (4-slot here) embedding
// cache. It covers:
//   - lookups on an empty cache;
//   - inserting entries and finding them again;
//   - replacing the value stored under an existing hash;
//   - evicting an entry once more distinct hashes than slots have been added.
//
// The order of addImage/findImage calls is significant: lookups may refresh an
// entry's recency, which decides which slot is evicted (A is expected to go).
func TestImageCache(t *testing.T) {
	cache := ImageContext{images: make([]imageCache, 4)}

	valA := [][]float32{{0.1, 0.2}, {0.3}}
	valB := [][]float32{{0.4}, {0.5}, {0.6}}
	valC := [][]float32{{0.7}}
	valD := [][]float32{{0.8}}
	valE := [][]float32{{0.9}}

	// Arbitrary image hashes; hashB2 deliberately stores the same value as hashB
	// under a different key.
	const (
		hashA  uint64 = 0x5adb61d31933a946
		hashB  uint64 = 0x011551369a34a901
		hashB2 uint64 = 0x756b218a517e7353
		hashD  uint64 = 0x75e5e8d35d7e3967
		hashE  uint64 = 0xd96f7f268ca0646e
	)

	// expectHit fails the test unless hash is cached with exactly want and no error.
	expectHit := func(hash uint64, want [][]float32) {
		t.Helper()
		got, err := cache.findImage(hash)
		if err != nil {
			t.Errorf("findImage(%#x): unexpected error %v (result %v)", hash, err, got)
			return
		}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("findImage(%#x) = %v, want %v", hash, got, want)
		}
	}

	// expectMiss fails the test unless the lookup of hash reports errImageNotFound.
	expectMiss := func(hash uint64) {
		t.Helper()
		got, err := cache.findImage(hash)
		if !errors.Is(err, errImageNotFound) {
			t.Errorf("findImage(%#x) = %v, %v; want errImageNotFound", hash, got, err)
		}
	}

	// Empty cache
	expectMiss(hashA)

	// Insert A
	cache.addImage(hashA, valA)
	expectHit(hashA, valA)

	// Insert B
	cache.addImage(hashB, valB)
	expectHit(hashA, valA)
	expectHit(hashB, valB)

	// Replace B with C (same hash, new value)
	cache.addImage(hashB, valC)
	expectHit(hashA, valA)
	expectHit(hashB, valC)

	// Fill the remaining slots and overflow by one; A is expected to be evicted.
	cache.addImage(hashB2, valB)
	cache.addImage(hashD, valD)
	cache.addImage(hashE, valE)
	expectMiss(hashA)
	expectHit(hashB2, valB)
	expectHit(hashB, valC)
	expectHit(hashD, valD)
	expectHit(hashE, valE)
}