diff --git a/integration/embed_test.go b/integration/embed_test.go index aeafa57b..61b36fa2 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -4,12 +4,45 @@ package integration import ( "context" + "math" "testing" "time" "github.com/ollama/ollama/api" ) +func floatsEqual32(a, b float32) bool { + return math.Abs(float64(a-b)) <= 1e-4 +} + +func floatsEqual64(a, b float64) bool { + return math.Abs(a-b) <= 1e-4 +} + +func TestAllMiniLMEmbeddings(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + req := api.EmbeddingRequest{ + Model: "all-minilm", + Prompt: "why is the sky blue?", + } + + res, err := embeddingTestHelper(ctx, t, req) + + if err != nil { + t.Fatalf("error: %v", err) + } + + if len(res.Embedding) != 384 { + t.Fatalf("expected 384 floats, got %d", len(res.Embedding)) + } + + if !floatsEqual64(res.Embedding[0], 0.06642947345972061) { + t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0]) + } +} + func TestAllMiniLMEmbed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() @@ -33,8 +66,8 @@ func TestAllMiniLMEmbed(t *testing.T) { t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) } - if res.Embeddings[0][0] != 0.010071031 { - t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0]) + if !floatsEqual32(res.Embeddings[0][0], 0.010071031) { + t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0]) } } @@ -61,12 +94,12 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) } - if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 { - t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0]) + if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) { + t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0]) } } -func TestAllMiniLmEmbedTruncate(t *testing.T) { +func TestAllMiniLMEmbedTruncate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() @@ -135,6 +168,22 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) { } } +func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatalf("failed to pull model %s: %v", req.Model, err) + } + + response, err := client.Embeddings(ctx, &req) + + if err != nil { + return nil, err + } + + return response, nil +} + func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup()