parent
a6cd8f6169
commit
ac33aa7d37
1 changed files with 54 additions and 5 deletions
|
@ -4,12 +4,45 @@ package integration
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func floatsEqual32(a, b float32) bool {
|
||||
return math.Abs(float64(a-b)) <= 1e-4
|
||||
}
|
||||
|
||||
func floatsEqual64(a, b float64) bool {
|
||||
return math.Abs(a-b) <= 1e-4
|
||||
}
|
||||
|
||||
func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if len(res.Embedding) != 384 {
|
||||
t.Fatalf("expected 384 floats, got %d", len(res.Embedding))
|
||||
}
|
||||
|
||||
if !floatsEqual64(res.Embedding[0], 0.06642947345972061) {
|
||||
t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllMiniLMEmbed(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
@ -33,8 +66,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
|||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||
}
|
||||
|
||||
if res.Embeddings[0][0] != 0.010071031 {
|
||||
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
|
||||
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
|
||||
t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,12 +94,12 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
|||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
|
||||
}
|
||||
|
||||
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
|
||||
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
|
||||
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
|
||||
t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
||||
func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
|
@ -135,6 +168,22 @@ func TestAllMiniLmEmbedTruncate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
}
|
||||
|
||||
response, err := client.Embeddings(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
|
Loading…
Reference in a new issue