// ollama/integration/concurrency_test.go (271 lines, 6.2 KiB, Go)

//go:build integration
package integration
import (
"context"
"log/slog"
2024-08-05 16:34:54 -07:00
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
2024-07-03 19:43:17 -07:00
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
)
func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "tinydolphin",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
2024-08-05 16:34:54 -07:00
{"sunlight"},
{"england", "english", "massachusetts", "pilgrims", "british"},
}
)
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for i := 0; i < len(req); i++ {
require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
}
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
DoGenerate(ctx, t, client, req[i], resp[i], 60*time.Second, 10*time.Second)
}(i)
}
wg.Wait()
}
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
req, resp := GenerateRequests()
reqLimit := len(req)
iterLimit := 5
2024-08-05 16:34:54 -07:00
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err)
// Don't hammer on small VRAM cards...
2024-08-05 16:34:54 -07:00
if maxVram < 4*format.GibiByte {
reqLimit = min(reqLimit, 2)
iterLimit = 2
}
}
ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial request
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
var wg sync.WaitGroup
wg.Add(reqLimit)
for i := 0; i < reqLimit; i++ {
go func(i int) {
defer wg.Done()
for j := 0; j < iterLimit; j++ {
slog.Info("Starting", "req", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
}
}(i)
}
wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
2024-07-03 19:43:17 -07:00
s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
if s == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}
2024-07-03 19:43:17 -07:00
maxVram, err := strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatal(err)
}
type model struct {
name string
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
}
smallModels := []model{
{
name: "orca-mini",
2024-07-03 19:43:17 -07:00
size: 2992 * format.MebiByte,
},
{
name: "phi",
2024-07-03 19:43:17 -07:00
size: 2616 * format.MebiByte,
},
{
name: "gemma:2b",
2024-07-03 19:43:17 -07:00
size: 2364 * format.MebiByte,
},
{
name: "stable-code:3b",
2024-07-03 19:43:17 -07:00
size: 2608 * format.MebiByte,
},
{
name: "starcoder2:3b",
2024-07-03 19:43:17 -07:00
size: 2166 * format.MebiByte,
},
}
mediumModels := []model{
{
name: "llama2",
2024-07-03 19:43:17 -07:00
size: 5118 * format.MebiByte,
},
{
name: "mistral",
2024-07-03 19:43:17 -07:00
size: 4620 * format.MebiByte,
},
{
name: "orca-mini:7b",
2024-07-03 19:43:17 -07:00
size: 5118 * format.MebiByte,
},
{
name: "dolphin-mistral",
2024-07-03 19:43:17 -07:00
size: 4620 * format.MebiByte,
},
{
name: "gemma:7b",
2024-07-03 19:43:17 -07:00
size: 5000 * format.MebiByte,
},
{
name: "codellama:7b",
size: 5118 * format.MebiByte,
},
}
// These seem to be too slow to be useful...
// largeModels := []model{
// {
// name: "llama2:13b",
2024-07-03 19:43:17 -07:00
// size: 7400 * format.MebiByte,
// },
// {
// name: "codellama:13b",
2024-07-03 19:43:17 -07:00
// size: 7400 * format.MebiByte,
// },
// {
// name: "orca-mini:13b",
2024-07-03 19:43:17 -07:00
// size: 7400 * format.MebiByte,
// },
// {
// name: "gemma:7b",
2024-07-03 19:43:17 -07:00
// size: 5000 * format.MebiByte,
// },
// {
// name: "starcoder2:15b",
2024-07-03 19:43:17 -07:00
// size: 9100 * format.MebiByte,
// },
// }
var chosenModels []model
switch {
2024-07-03 19:43:17 -07:00
case maxVram < 10000*format.MebiByte:
slog.Info("selecting small models")
chosenModels = smallModels
2024-07-03 19:43:17 -07:00
// case maxVram < 30000*format.MebiByte:
default:
slog.Info("selecting medium models")
chosenModels = mediumModels
// default:
// slog.Info("selecting large models")
// chosenModels = largModels
}
req, resp := GenerateRequests()
for i := range req {
if i > len(chosenModels) {
break
}
req[i].Model = chosenModels[i].name
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure all the models are pulled before we get started
for _, r := range req {
require.NoError(t, PullIfMissing(ctx, client, r.Model))
}
var wg sync.WaitGroup
2024-07-03 19:43:17 -07:00
consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
for i := 0; i < len(req); i++ {
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
2024-08-05 16:34:54 -07:00
if i > 1 && consumed > maxVram {
slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
break
}
consumed += chosenModels[i].size
2024-08-05 16:34:54 -07:00
slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 3; j++ {
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
2024-05-10 14:13:26 -07:00
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
}
}(i)
}
go func() {
for {
time.Sleep(2 * time.Second)
select {
case <-ctx.Done():
return
default:
models, err := client.ListRunning(ctx)
if err != nil {
slog.Warn("failed to list running models", "error", err)
continue
}
for _, m := range models.Models {
slog.Info("loaded model snapshot", "model", m)
}
}
}
}()
wg.Wait()
}