Integration tests conditionally pull

If images aren't present, pull them.
Also fixes the expected responses
This commit is contained in:
Daniel Hiltgen 2024-03-24 16:22:38 -07:00
parent acfa2b9422
commit 7b6cbc10ec
4 changed files with 70 additions and 10 deletions

View file

@ -12,7 +12,7 @@ import (
) )
func TestOrcaMiniBlueSky(t *testing.T) { func TestOrcaMiniBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{

View file

@ -30,7 +30,7 @@ func TestIntegrationMultimodal(t *testing.T) {
} }
resp := "the ollamas" resp := "the ollamas"
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel() defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp}) GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
} }

View file

@ -40,16 +40,16 @@ var (
}, },
}, },
} }
resp = [2]string{ resp = [2][]string{
"scattering", []string{"sunlight"},
"united states thanksgiving", []string{"england", "english", "massachusetts", "pilgrims"},
} }
) )
func TestIntegrationSimpleOrcaMini(t *testing.T) { func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel() defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]}) GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
} }
// TODO // TODO
@ -59,12 +59,12 @@ func TestIntegrationSimpleOrcaMini(t *testing.T) {
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) { func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(req)) wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel() defer cancel()
for i := 0; i < len(req); i++ { for i := 0; i < len(req); i++ {
go func(i int) { go func(i int) {
defer wg.Done() defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]}) GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
}(i) }(i)
} }
wg.Wait() wg.Wait()

View file

@ -125,6 +125,55 @@ func StartServer(ctx context.Context, ollamaHost string) error {
return nil return nil
} }
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
slog.Debug("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
return err
}
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
if err != nil {
return err
}
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode == 200 {
slog.Info("model already present", "model", modelName)
return nil
}
slog.Info("model missing", "status", response.StatusCode)
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
requestJSON, err = json.Marshal(pullReq)
if err != nil {
return err
}
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
if err != nil {
return err
}
slog.Info("pulling", "model", modelName)
response, err = client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != 200 {
return fmt.Errorf("failed to pull model") // TODO more details perhaps
}
slog.Info("model pulled", "model", modelName)
return nil
}
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) { func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
requestJSON, err := json.Marshal(genReq) requestJSON, err := json.Marshal(genReq)
if err != nil { if err != nil {
@ -158,6 +207,11 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
assert.NoError(t, StartServer(ctx, testEndpoint)) assert.NoError(t, StartServer(ctx, testEndpoint))
} }
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
if err != nil {
t.Fatalf("Error pulling model: %v", err)
}
// Make the request and get the response // Make the request and get the response
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON)) req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
if err != nil { if err != nil {
@ -172,6 +226,7 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
if err != nil { if err != nil {
t.Fatalf("Error making request: %v", err) t.Fatalf("Error making request: %v", err)
} }
defer response.Body.Close()
body, err := io.ReadAll(response.Body) body, err := io.ReadAll(response.Body)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body)) assert.Equal(t, response.StatusCode, 200, string(body))
@ -184,7 +239,12 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
} }
// Verify the response contains the expected data // Verify the response contains the expected data
atLeastOne := false
for _, resp := range anyResp { for _, resp := range anyResp {
assert.Contains(t, strings.ToLower(payload.Response), resp) if strings.Contains(strings.ToLower(payload.Response), resp) {
atLeastOne = true
break
} }
}
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
} }