From 949b6c01e074c6f7712d7da37079218b3192b102 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Sat, 23 Mar 2024 14:24:18 +0100 Subject: [PATCH] Revamp go based integration tests This uplevels the integration tests to run the server which can allow testing an existing server, or a remote server. --- integration/README.md | 11 ++ integration/basic_test.go | 28 ++++ {server => integration}/llm_image_test.go | 33 ++-- integration/llm_test.go | 73 +++++++++ integration/utils_test.go | 190 ++++++++++++++++++++++ scripts/setup_integration_tests.sh | 41 ----- server/llm_test.go | 123 -------------- server/llm_utils_test.go | 75 --------- 8 files changed, 313 insertions(+), 261 deletions(-) create mode 100644 integration/README.md create mode 100644 integration/basic_test.go rename {server => integration}/llm_image_test.go (98%) create mode 100644 integration/llm_test.go create mode 100644 integration/utils_test.go delete mode 100755 scripts/setup_integration_tests.sh delete mode 100644 server/llm_test.go delete mode 100644 server/llm_utils_test.go diff --git a/integration/README.md b/integration/README.md new file mode 100644 index 00000000..e2bdd6b2 --- /dev/null +++ b/integration/README.md @@ -0,0 +1,11 @@ +# Integration Tests + +This directory contains integration tests to exercise Ollama end-to-end to verify behavior + +By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` + + +The integration tests have 2 modes of operating. + +1. By default, they will start the server on a random port, run the tests, and then shutdown the server. +2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote diff --git a/integration/basic_test.go b/integration/basic_test.go new file mode 100644 index 00000000..3cd5c354 --- /dev/null +++ b/integration/basic_test.go @@ -0,0 +1,28 @@ +//go:build integration + +package integration + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/jmorganca/ollama/api" +) + +func TestOrcaMiniBlueSky(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + // Set up the test data + req := api.GenerateRequest{ + Model: "orca-mini", + Prompt: "why is the sky blue?", + Stream: &stream, + Options: map[string]interface{}{ + "temperature": 0, + "seed": 123, + }, + } + GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"}) +} diff --git a/server/llm_image_test.go b/integration/llm_image_test.go similarity index 98% rename from server/llm_image_test.go rename to integration/llm_image_test.go index 6ca3577b..e7b6754c 100644 --- a/server/llm_image_test.go +++ b/integration/llm_image_test.go @@ -1,49 +1,38 @@ //go:build integration -package server +package integration import ( "context" "encoding/base64" - "log" - "os" - "strings" + "net/http" "testing" "time" "github.com/jmorganca/ollama/api" - "github.com/jmorganca/ollama/llm" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestIntegrationMultimodal(t *testing.T) { - SkipIFNoTestData(t) image, err := base64.StdEncoding.DecodeString(imageEncoding) require.NoError(t, err) req := api.GenerateRequest{ - Model: "llava:7b", - Prompt: "what does the text in this image say?", - Options: map[string]interface{}{}, + Model: "llava:7b", + Prompt: "what does the text in this image say?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, Images: []api.ImageData{ image, }, } + resp := "the ollamas" - workDir, err := os.MkdirTemp("", "ollama") - require.NoError(t, err) - defer os.RemoveAll(workDir) - require.NoError(t, llm.Init(workDir)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) defer cancel() - opts := api.DefaultOptions() - opts.Seed = 42 - opts.Temperature = 0.0 - model, llmRunner := PrepareModelForPrompts(t, req.Model, opts) - defer llmRunner.Close() - response := OneShotPromptResponse(t, ctx, req, model, llmRunner) - log.Print(response) - assert.Contains(t, strings.ToLower(response), resp) + GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp}) } const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb diff --git a/integration/llm_test.go b/integration/llm_test.go new file mode 100644 index 00000000..805c573d --- /dev/null +++ b/integration/llm_test.go @@ -0,0 +1,73 @@ +//go:build integration + +package integration + +import ( + "context" + "net/http" + "sync" + "testing" + "time" + + "github.com/jmorganca/ollama/api" +) + +// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server +// package to avoid circular dependencies + +// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server) +// +// TODO - Fix this ^^ + +var ( + stream = false + req = [2]api.GenerateRequest{ + { + Model: "orca-mini", + Prompt: "why is the ocean blue?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "what is the origin of the us thanksgiving holiday?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, + } + resp = [2]string{ + "scattering", + "united states thanksgiving", + } +) + +func TestIntegrationSimpleOrcaMini(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) + defer cancel() + GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]}) +} + +// TODO +// The server always loads a new runner and closes the old one, which forces serial execution +// At present this test case fails with concurrency problems. Eventually we should try to +// get true concurrency working with n_parallel support in the backend +func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { + var wg sync.WaitGroup + wg.Add(len(req)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) + defer cancel() + for i := 0; i < len(req); i++ { + go func(i int) { + defer wg.Done() + GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]}) + }(i) + } + wg.Wait() +} + +// TODO - create a parallel test with 2 different models once we support concurrency diff --git a/integration/utils_test.go b/integration/utils_test.go new file mode 100644 index 00000000..c28ae66f --- /dev/null +++ b/integration/utils_test.go @@ -0,0 +1,190 @@ +//go:build integration + +package integration + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "math/rand" + "net" + "net/http" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/app/lifecycle" + "github.com/stretchr/testify/assert" +) + +func FindPort() string { + port := 0 + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + port = l.Addr().(*net.TCPAddr).Port + l.Close() + } + } + if port == 0 { + port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range + } + return strconv.Itoa(port) +} + +func GetTestEndpoint() (string, string) { + defaultPort := "11434" + ollamaHost := os.Getenv("OLLAMA_HOST") + + scheme, hostport, ok := strings.Cut(ollamaHost, "://") + if !ok { + scheme, hostport = "http", ollamaHost + } + + // trim trailing slashes + hostport = strings.TrimRight(hostport, "/") + + host, port, err := net.SplitHostPort(hostport) + if err != nil { + host, port = "127.0.0.1", defaultPort + if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { + host = ip.String() + } else if hostport != "" { + host = hostport + } + } + + if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort { + port = FindPort() + } + + url := fmt.Sprintf("%s:%s", host, port) + slog.Info("server connection", "url", url) + return scheme, url +} + +// TODO make fanicier, grab logs, etc. +var serverMutex sync.Mutex +var serverReady bool + +func StartServer(ctx context.Context, ollamaHost string) error { + // Make sure the server has been built + CLIName, err := filepath.Abs("../ollama") + if err != nil { + return err + } + + if runtime.GOOS == "windows" { + CLIName += ".exe" + } + _, err = os.Stat(CLIName) + if err != nil { + return fmt.Errorf("CLI missing, did you forget to build first? %w", err) + } + serverMutex.Lock() + defer serverMutex.Unlock() + if serverReady { + return nil + } + + if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost { + slog.Info("setting env", "OLLAMA_HOST", ollamaHost) + os.Setenv("OLLAMA_HOST", ollamaHost) + } + + slog.Info("starting server", "url", ollamaHost) + done, err := lifecycle.SpawnServer(ctx, "../ollama") + if err != nil { + return fmt.Errorf("failed to start server: %w", err) + } + + go func() { + <-ctx.Done() + serverMutex.Lock() + defer serverMutex.Unlock() + exitCode := <-done + if exitCode > 0 { + slog.Warn("server failure", "exit", exitCode) + } + serverReady = false + }() + + // TODO wait only long enough for the server to be responsive... + time.Sleep(500 * time.Millisecond) + + serverReady = true + return nil +} + +func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) { + requestJSON, err := json.Marshal(genReq) + if err != nil { + t.Fatalf("Error serializing request: %v", err) + } + defer func() { + if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" { + // TODO + fp, err := os.Open(lifecycle.ServerLogFile) + if err != nil { + slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err) + return + } + data, err := io.ReadAll(fp) + if err != nil { + slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err) + return + } + slog.Warn("SERVER LOG FOLLOWS") + os.Stderr.Write(data) + slog.Warn("END OF SERVER") + } + err = os.Remove(lifecycle.ServerLogFile) + if err != nil && !os.IsNotExist(err) { + slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err) + } + }() + scheme, testEndpoint := GetTestEndpoint() + + if os.Getenv("OLLAMA_TEST_EXISTING") == "" { + assert.NoError(t, StartServer(ctx, testEndpoint)) + } + + // Make the request and get the response + req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON)) + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Set the content type for the request + req.Header.Set("Content-Type", "application/json") + + // Make the request with the HTTP client + response, err := client.Do(req.WithContext(ctx)) + if err != nil { + t.Fatalf("Error making request: %v", err) + } + body, err := io.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, response.StatusCode, 200, string(body)) + + // Verify the response is valid JSON + var payload api.GenerateResponse + err = json.Unmarshal(body, &payload) + if err != nil { + assert.NoError(t, err, body) + } + + // Verify the response contains the expected data + for _, resp := range anyResp { + assert.Contains(t, strings.ToLower(payload.Response), resp) + } +} diff --git a/scripts/setup_integration_tests.sh b/scripts/setup_integration_tests.sh deleted file mode 100755 index 851ce49d..00000000 --- a/scripts/setup_integration_tests.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -# This script sets up integration tests which run the full stack to verify -# inference locally -# -# To run the relevant tests use -# go test -tags=integration ./server -set -e -set -o pipefail - -REPO=$(dirname $0)/../ -export OLLAMA_MODELS=${REPO}/test_data/models -REGISTRY_SCHEME=https -REGISTRY=registry.ollama.ai -TEST_MODELS=("library/orca-mini:latest" "library/llava:7b") -ACCEPT_HEADER="Accept: application/vnd.docker.distribution.manifest.v2+json" - -for model in ${TEST_MODELS[@]}; do - TEST_MODEL=$(echo ${model} | cut -f1 -d:) - TEST_MODEL_TAG=$(echo ${model} | cut -f2 -d:) - mkdir -p ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/ - mkdir -p ${OLLAMA_MODELS}/blobs/ - - echo "Pulling manifest for ${TEST_MODEL}:${TEST_MODEL_TAG}" - curl -s --header "${ACCEPT_HEADER}" \ - -o ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} \ - ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/manifests/${TEST_MODEL_TAG} - - CFG_HASH=$(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".config.digest") - echo "Pulling config blob ${CFG_HASH}" - curl -L -C - --header "${ACCEPT_HEADER}" \ - -o ${OLLAMA_MODELS}/blobs/${CFG_HASH} \ - ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${CFG_HASH} - - for LAYER in $(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".layers[].digest"); do - echo "Pulling blob ${LAYER}" - curl -L -C - --header "${ACCEPT_HEADER}" \ - -o ${OLLAMA_MODELS}/blobs/${LAYER} \ - ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${LAYER} - done -done diff --git a/server/llm_test.go b/server/llm_test.go deleted file mode 100644 index 7f63a64d..00000000 --- a/server/llm_test.go +++ /dev/null @@ -1,123 +0,0 @@ -//go:build integration - -package server - -import ( - "context" - "os" - "strings" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/jmorganca/ollama/api" - "github.com/jmorganca/ollama/llm" -) - -// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server -// package to avoid circular dependencies - -// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server) -// -// TODO - Fix this ^^ - -var ( - req = [2]api.GenerateRequest{ - { - Model: "orca-mini", - Prompt: "tell me a short story about agi?", - Options: map[string]interface{}{}, - }, { - Model: "orca-mini", - Prompt: "what is the origin of the us thanksgiving holiday?", - Options: map[string]interface{}{}, - }, - } - resp = [2]string{ - "once upon a time", - "united states thanksgiving", - } -) - -func TestIntegrationSimpleOrcaMini(t *testing.T) { - SkipIFNoTestData(t) - workDir, err := os.MkdirTemp("", "ollama") - require.NoError(t, err) - defer os.RemoveAll(workDir) - require.NoError(t, llm.Init(workDir)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) - defer cancel() - opts := api.DefaultOptions() - opts.Seed = 42 - opts.Temperature = 0.0 - model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts) - defer llmRunner.Close() - response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner) - assert.Contains(t, strings.ToLower(response), resp[0]) -} - -// TODO -// The server always loads a new runner and closes the old one, which forces serial execution -// At present this test case fails with concurrency problems. Eventually we should try to -// get true concurrency working with n_parallel support in the backend -func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { - SkipIFNoTestData(t) - - t.Skip("concurrent prediction on single runner not currently supported") - - workDir, err := os.MkdirTemp("", "ollama") - require.NoError(t, err) - defer os.RemoveAll(workDir) - require.NoError(t, llm.Init(workDir)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) - defer cancel() - opts := api.DefaultOptions() - opts.Seed = 42 - opts.Temperature = 0.0 - var wg sync.WaitGroup - wg.Add(len(req)) - model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts) - defer llmRunner.Close() - for i := 0; i < len(req); i++ { - go func(i int) { - defer wg.Done() - response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner) - t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response) - assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt) - }(i) - } - wg.Wait() -} - -func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) { - SkipIFNoTestData(t) - workDir, err := os.MkdirTemp("", "ollama") - require.NoError(t, err) - defer os.RemoveAll(workDir) - require.NoError(t, llm.Init(workDir)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) - defer cancel() - opts := api.DefaultOptions() - opts.Seed = 42 - opts.Temperature = 0.0 - var wg sync.WaitGroup - wg.Add(len(req)) - - t.Logf("Running %d concurrently", len(req)) - for i := 0; i < len(req); i++ { - go func(i int) { - defer wg.Done() - model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts) - defer llmRunner.Close() - response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner) - t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response) - assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt) - }(i) - } - wg.Wait() -} - -// TODO - create a parallel test with 2 different models once we support concurrency diff --git a/server/llm_utils_test.go b/server/llm_utils_test.go deleted file mode 100644 index 92ee1039..00000000 --- a/server/llm_utils_test.go +++ /dev/null @@ -1,75 +0,0 @@ -//go:build integration - -package server - -import ( - "context" - "errors" - "os" - "path" - "runtime" - "testing" - - "github.com/jmorganca/ollama/api" - "github.com/jmorganca/ollama/llm" - "github.com/stretchr/testify/require" -) - -func SkipIFNoTestData(t *testing.T) { - modelDir := getModelDir() - if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) { - t.Skipf("%s does not exist - skipping integration tests", modelDir) - } -} - -func getModelDir() string { - _, filename, _, _ := runtime.Caller(0) - return path.Dir(path.Dir(filename) + "/../test_data/models/.") -} - -func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) { - modelDir := getModelDir() - os.Setenv("OLLAMA_MODELS", modelDir) - model, err := GetModel(modelName) - require.NoError(t, err, "GetModel ") - err = opts.FromMap(model.Options) - require.NoError(t, err, "opts from model ") - runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts) - require.NoError(t, err, "llm.New failed") - return model, runner -} - -func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string { - prompt, err := model.PreResponsePrompt(PromptVars{ - System: req.System, - Prompt: req.Prompt, - First: len(req.Context) == 0, - }) - require.NoError(t, err, "prompt generation failed") - success := make(chan bool, 1) - response := "" - cb := func(r llm.PredictResult) { - - if !r.Done { - response += r.Content - } else { - success <- true - } - } - - predictReq := llm.PredictOpts{ - Prompt: prompt, - Format: req.Format, - Images: req.Images, - } - err = runner.Predict(ctx, predictReq, cb) - require.NoError(t, err, "predict call failed") - - select { - case <-ctx.Done(): - t.Errorf("failed to complete before timeout: \n%s", response) - return "" - case <-success: - return response - } -}