//go:build integration package integration import ( "bytes" "context" "errors" "fmt" "io" "log/slog" "math/rand" "net" "net/http" "net/url" "os" "path/filepath" "runtime" "strconv" "strings" "sync" "testing" "time" "github.com/ollama/ollama/api" "github.com/ollama/ollama/app/lifecycle" "github.com/stretchr/testify/require" ) func Init() { lifecycle.InitLogging() } func FindPort() string { port := 0 if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { var l *net.TCPListener if l, err = net.ListenTCP("tcp", a); err == nil { port = l.Addr().(*net.TCPAddr).Port l.Close() } } if port == 0 { port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range } return strconv.Itoa(port) } func GetTestEndpoint() (*api.Client, string) { defaultPort := "11434" ollamaHost := os.Getenv("OLLAMA_HOST") scheme, hostport, ok := strings.Cut(ollamaHost, "://") if !ok { scheme, hostport = "http", ollamaHost } // trim trailing slashes hostport = strings.TrimRight(hostport, "/") host, port, err := net.SplitHostPort(hostport) if err != nil { host, port = "127.0.0.1", defaultPort if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { host = ip.String() } else if hostport != "" { host = hostport } } if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort { port = FindPort() } slog.Info("server connection", "host", host, "port", port) return api.NewClient( &url.URL{ Scheme: scheme, Host: net.JoinHostPort(host, port), }, http.DefaultClient), fmt.Sprintf("%s:%s", host, port) } var serverMutex sync.Mutex var serverReady bool func startServer(ctx context.Context, ollamaHost string) error { // Make sure the server has been built CLIName, err := filepath.Abs("../ollama") if err != nil { return err } if runtime.GOOS == "windows" { CLIName += ".exe" } _, err = os.Stat(CLIName) if err != nil { return fmt.Errorf("CLI missing, did you forget to build first? %w", err) } serverMutex.Lock() defer serverMutex.Unlock() if serverReady { return nil } if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost { slog.Info("setting env", "OLLAMA_HOST", ollamaHost) os.Setenv("OLLAMA_HOST", ollamaHost) } slog.Info("starting server", "url", ollamaHost) done, err := lifecycle.SpawnServer(ctx, "../ollama") if err != nil { return fmt.Errorf("failed to start server: %w", err) } go func() { <-ctx.Done() serverMutex.Lock() defer serverMutex.Unlock() exitCode := <-done if exitCode > 0 { slog.Warn("server failure", "exit", exitCode) } serverReady = false }() // TODO wait only long enough for the server to be responsive... time.Sleep(500 * time.Millisecond) serverReady = true return nil } func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error { slog.Info("checking status of model", "model", modelName) showReq := &api.ShowRequest{Name: modelName} showCtx, cancel := context.WithDeadlineCause( ctx, time.Now().Add(5*time.Second), fmt.Errorf("show for existing model %s took too long", modelName), ) defer cancel() _, err := client.Show(showCtx, showReq) var statusError api.StatusError switch { case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: break case err != nil: return err default: slog.Info("model already present", "model", modelName) return nil } slog.Info("model missing", "model", modelName) stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models stallTimer := time.NewTimer(stallDuration) fn := func(resp api.ProgressResponse) error { // fmt.Print(".") if !stallTimer.Reset(stallDuration) { return fmt.Errorf("stall was detected, aborting status reporting") } return nil } stream := true pullReq := &api.PullRequest{Name: modelName, Stream: &stream} var pullError error done := make(chan int) go func() { pullError = client.Pull(ctx, pullReq, fn) done <- 0 }() select { case <-stallTimer.C: return fmt.Errorf("download stalled") case <-done: return pullError } } var serverProcMutex sync.Mutex // Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors // Starts the server if needed func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) { client, testEndpoint := GetTestEndpoint() if os.Getenv("OLLAMA_TEST_EXISTING") == "" { serverProcMutex.Lock() fp, err := os.CreateTemp("", "ollama-server-*.log") if err != nil { t.Fatalf("failed to generate log file: %s", err) } lifecycle.ServerLogFile = fp.Name() fp.Close() require.NoError(t, startServer(ctx, testEndpoint)) } return client, testEndpoint, func() { if os.Getenv("OLLAMA_TEST_EXISTING") == "" { defer serverProcMutex.Unlock() if t.Failed() { fp, err := os.Open(lifecycle.ServerLogFile) if err != nil { slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err) return } data, err := io.ReadAll(fp) if err != nil { slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err) return } slog.Warn("SERVER LOG FOLLOWS") os.Stderr.Write(data) slog.Warn("END OF SERVER") } err := os.Remove(lifecycle.ServerLogFile) if err != nil && !os.IsNotExist(err) { slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err) } } } } func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() require.NoError(t, PullIfMissing(ctx, client, genReq.Model)) DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second) } func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) { stallTimer := time.NewTimer(initialTimeout) var buf bytes.Buffer fn := func(response api.GenerateResponse) error { // fmt.Print(".") buf.Write([]byte(response.Response)) if !stallTimer.Reset(streamTimeout) { return fmt.Errorf("stall was detected while streaming response, aborting") } return nil } stream := true genReq.Stream = &stream done := make(chan int) var genErr error go func() { genErr = client.Generate(ctx, &genReq, fn) done <- 0 }() select { case <-stallTimer.C: if buf.Len() == 0 { t.Errorf("generate never started. Timed out after :%s", initialTimeout.String()) } else { t.Errorf("generate stalled. Response so far:%s", buf.String()) } case <-done: require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt) // Verify the response contains the expected data response := buf.String() atLeastOne := false for _, resp := range anyResp { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } require.True(t, atLeastOne, "none of %v found in %s", anyResp, response) slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) case <-ctx.Done(): t.Error("outer test context done while waiting for generate") } } // Generate a set of requests // By default each request uses orca-mini as the model func GenerateRequests() ([]api.GenerateRequest, [][]string) { return []api.GenerateRequest{ { Model: "orca-mini", Prompt: "why is the ocean blue?", Stream: &stream, Options: map[string]interface{}{ "seed": 42, "temperature": 0.0, }, }, { Model: "orca-mini", Prompt: "why is the color of dirt brown?", Stream: &stream, Options: map[string]interface{}{ "seed": 42, "temperature": 0.0, }, }, { Model: "orca-mini", Prompt: "what is the origin of the us thanksgiving holiday?", Stream: &stream, Options: map[string]interface{}{ "seed": 42, "temperature": 0.0, }, }, { Model: "orca-mini", Prompt: "what is the origin of independence day?", Stream: &stream, Options: map[string]interface{}{ "seed": 42, "temperature": 0.0, }, }, { Model: "orca-mini", Prompt: "what is the composition of air?", Stream: &stream, Options: map[string]interface{}{ "seed": 42, "temperature": 0.0, }, }, }, [][]string{ []string{"sunlight"}, []string{"soil", "organic", "earth", "black", "tan"}, []string{"england", "english", "massachusetts", "pilgrims"}, []string{"fourth", "july", "declaration", "independence"}, []string{"nitrogen", "oxygen", "carbon", "dioxide"}, } }