diff --git a/integration/basic_test.go b/integration/basic_test.go index 3cd5c354..ce933ffe 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -12,7 +12,7 @@ import ( ) func TestOrcaMiniBlueSky(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) defer cancel() // Set up the test data req := api.GenerateRequest{ diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index e7b6754c..0ac0e1e8 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -30,7 +30,7 @@ func TestIntegrationMultimodal(t *testing.T) { } resp := "the ollamas" - ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp}) } diff --git a/integration/llm_test.go b/integration/llm_test.go index 805c573d..107b5573 100644 --- a/integration/llm_test.go +++ b/integration/llm_test.go @@ -40,16 +40,16 @@ var ( }, }, } - resp = [2]string{ - "scattering", - "united states thanksgiving", + resp = [2][]string{ + []string{"sunlight"}, + []string{"england", "english", "massachusetts", "pilgrims"}, } ) func TestIntegrationSimpleOrcaMini(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) defer cancel() - GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]}) + GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0]) } // TODO @@ -59,12 +59,12 @@ func TestIntegrationSimpleOrcaMini(t *testing.T) { func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { var wg sync.WaitGroup wg.Add(len(req)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) defer cancel() for i := 0; i < len(req); i++ { go func(i int) { defer wg.Done() - GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]}) + GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i]) }(i) } wg.Wait() diff --git a/integration/utils_test.go b/integration/utils_test.go index c28ae66f..47184af8 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -125,6 +125,55 @@ func StartServer(ctx context.Context, ollamaHost string) error { return nil } +func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error { + slog.Debug("checking status of model", "model", modelName) + showReq := &api.ShowRequest{Name: modelName} + requestJSON, err := json.Marshal(showReq) + if err != nil { + return err + } + + req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON)) + if err != nil { + return err + } + + // Make the request with the HTTP client + response, err := client.Do(req.WithContext(ctx)) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode == 200 { + slog.Info("model already present", "model", modelName) + return nil + } + slog.Info("model missing", "status", response.StatusCode) + + pullReq := &api.PullRequest{Name: modelName, Stream: &stream} + requestJSON, err = json.Marshal(pullReq) + if err != nil { + return err + } + + req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON)) + if err != nil { + return err + } + slog.Info("pulling", "model", modelName) + + response, err = client.Do(req.WithContext(ctx)) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode != 200 { + return fmt.Errorf("failed to pull model") // TODO more details perhaps + } + slog.Info("model pulled", "model", modelName) + return nil +} + func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) { requestJSON, err := json.Marshal(genReq) if err != nil { @@ -158,6 +207,11 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, assert.NoError(t, StartServer(ctx, testEndpoint)) } + err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model) + if err != nil { + t.Fatalf("Error pulling model: %v", err) + } + // Make the request and get the response req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON)) if err != nil { @@ -172,6 +226,7 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, if err != nil { t.Fatalf("Error making request: %v", err) } + defer response.Body.Close() body, err := io.ReadAll(response.Body) assert.NoError(t, err) assert.Equal(t, response.StatusCode, 200, string(body)) @@ -184,7 +239,12 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, } // Verify the response contains the expected data + atLeastOne := false for _, resp := range anyResp { - assert.Contains(t, strings.ToLower(payload.Response), resp) + if strings.Contains(strings.ToLower(payload.Response), resp) { + atLeastOne = true + break + } } + assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response) }