Commit 1954ec5917 (parent 0f1910129f)
Michael Yang, 2024-07-03 19:43:17 -07:00

8 changed files with 37 additions and 59 deletions

api/client_test.go

@@ -2,8 +2,6 @@ package api

 import (
 	"testing"
-
-	"github.com/ollama/ollama/envconfig"
 )

 func TestClientFromEnvironment(t *testing.T) {

@@ -33,7 +31,6 @@ func TestClientFromEnvironment(t *testing.T) {
 	for k, v := range testCases {
 		t.Run(k, func(t *testing.T) {
 			t.Setenv("OLLAMA_HOST", v.value)
-			envconfig.LoadConfig()

 			client, err := ClientFromEnvironment()
 			if err != v.err {
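
Every removal in this commit has the shape of the hunk above: a test sets a variable with t.Setenv and then called envconfig.LoadConfig() to refresh a cached snapshot of the environment. Dropping the call only works if envconfig now resolves variables at lookup time, so the value set by t.Setenv is visible with no reload step. A minimal sketch of that lazy-lookup style, assuming a getter-per-variable design; the name, default, and signature here are illustrative, not the package's actual API:

package envconfig

import "os"

// Host reads OLLAMA_HOST on every call instead of returning a value
// cached by an earlier LoadConfig, so a test's t.Setenv("OLLAMA_HOST", ...)
// takes effect immediately.
func Host() string {
	if v := os.Getenv("OLLAMA_HOST"); v != "" {
		return v
	}
	return "127.0.0.1:11434" // assumed default, for illustration only
}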

integration/concurrency_test.go

@@ -5,14 +5,16 @@ package integration

 import (
 	"context"
 	"log/slog"
 	"os"
 	"strconv"
 	"sync"
 	"testing"
 	"time"

-	"github.com/ollama/ollama/api"
 	"github.com/stretchr/testify/require"
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/format"
 )

 func TestMultiModelConcurrency(t *testing.T) {

@@ -106,13 +108,16 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {

 // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
 func TestMultiModelStress(t *testing.T) {
-	vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
-	if vram == "" {
+	s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
+	if s == "" {
 		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
 	}
-	max, err := strconv.ParseUint(vram, 10, 64)
-	require.NoError(t, err)
-	const MB = uint64(1024 * 1024)
+
+	maxVram, err := strconv.ParseUint(s, 10, 64)
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	type model struct {
 		name string
 		size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM

@@ -121,83 +126,82 @@ func TestMultiModelStress(t *testing.T) {
 	smallModels := []model{
 		{
 			name: "orca-mini",
-			size: 2992 * MB,
+			size: 2992 * format.MebiByte,
 		},
 		{
 			name: "phi",
-			size: 2616 * MB,
+			size: 2616 * format.MebiByte,
 		},
 		{
 			name: "gemma:2b",
-			size: 2364 * MB,
+			size: 2364 * format.MebiByte,
 		},
 		{
 			name: "stable-code:3b",
-			size: 2608 * MB,
+			size: 2608 * format.MebiByte,
 		},
 		{
 			name: "starcoder2:3b",
-			size: 2166 * MB,
+			size: 2166 * format.MebiByte,
 		},
 	}
 	mediumModels := []model{
 		{
 			name: "llama2",
-			size: 5118 * MB,
+			size: 5118 * format.MebiByte,
 		},
 		{
 			name: "mistral",
-			size: 4620 * MB,
+			size: 4620 * format.MebiByte,
 		},
 		{
 			name: "orca-mini:7b",
-			size: 5118 * MB,
+			size: 5118 * format.MebiByte,
 		},
 		{
 			name: "dolphin-mistral",
-			size: 4620 * MB,
+			size: 4620 * format.MebiByte,
 		},
 		{
 			name: "gemma:7b",
-			size: 5000 * MB,
+			size: 5000 * format.MebiByte,
+		},
+		{
+			name: "codellama:7b",
+			size: 5118 * format.MebiByte,
 		},
-		// TODO - uncomment this once #3565 is merged and this is rebased on it
-		// {
-		// 	name: "codellama:7b",
-		// 	size: 5118 * MB,
-		// },
 	}

 	// These seem to be too slow to be useful...
 	// largeModels := []model{
 	// 	{
 	// 		name: "llama2:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "codellama:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "orca-mini:13b",
-	// 		size: 7400 * MB,
+	// 		size: 7400 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "gemma:7b",
-	// 		size: 5000 * MB,
+	// 		size: 5000 * format.MebiByte,
 	// 	},
 	// 	{
 	// 		name: "starcoder2:15b",
-	// 		size: 9100 * MB,
+	// 		size: 9100 * format.MebiByte,
 	// 	},
 	// }

 	var chosenModels []model
 	switch {
-	case max < 10000*MB:
+	case maxVram < 10000*format.MebiByte:
 		slog.Info("selecting small models")
 		chosenModels = smallModels
-	// case max < 30000*MB:
+	// case maxVram < 30000*format.MebiByte:
 	default:
 		slog.Info("selecting medium models")
 		chosenModels = mediumModels

@@ -226,15 +230,15 @@ func TestMultiModelStress(t *testing.T) {
 	}

 	var wg sync.WaitGroup
-	consumed := uint64(256 * MB) // Assume some baseline usage
+	consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
 	for i := 0; i < len(req); i++ {
 		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
-		if i > 1 && consumed > max {
-			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+		if i > 1 && consumed > maxVram {
+			slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
 			break
 		}
 		consumed += chosenModels[i].size
-		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+		slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
 		wg.Add(1)
 		go func(i int) {
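
Beyond the LoadConfig removals, this file also drops its test-local MB constant and the manual /1024/1024 math in favor of the shared helpers in github.com/ollama/ollama/format. A rough sketch of what the test now relies on, assuming MebiByte is a plain binary-unit constant and HumanBytes2 formats a byte count with a unit suffix; the bodies below are stand-ins, not the package's actual implementation:

package main

import "fmt"

// Stand-ins for format.MebiByte and format.HumanBytes2.
const (
	KibiByte = uint64(1024)
	MebiByte = 1024 * KibiByte
	GibiByte = 1024 * MebiByte
)

func HumanBytes2(b uint64) string {
	switch {
	case b >= GibiByte:
		return fmt.Sprintf("%.1f GiB", float64(b)/float64(GibiByte))
	case b >= MebiByte:
		return fmt.Sprintf("%.1f MiB", float64(b)/float64(MebiByte))
	}
	return fmt.Sprintf("%d B", b)
}

func main() {
	// The same size the model table writes as 2992 * format.MebiByte.
	fmt.Println(HumanBytes2(2992 * MebiByte)) // prints "2.9 GiB"
}

Using one shared constant keeps every size in the test a uint64 in the same unit, so the maxVram comparisons above stay like-for-like, and the log lines become human-readable instead of raw megabyte division.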

server/manifest_test.go

@@ -7,7 +7,6 @@ import (
 	"slices"
 	"testing"

-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/types/model"
 )

@@ -108,7 +107,6 @@ func TestManifests(t *testing.T) {
 		t.Run(n, func(t *testing.T) {
 			d := t.TempDir()
 			t.Setenv("OLLAMA_MODELS", d)
-			envconfig.LoadConfig()

 			for _, p := range wants.ps {
 				createManifest(t, d, p)

server/modelpath_test.go

@@ -7,8 +7,6 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-
-	"github.com/ollama/ollama/envconfig"
 )

 func TestGetBlobsPath(t *testing.T) {

@@ -63,7 +61,6 @@ func TestGetBlobsPath(t *testing.T) {
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
 			t.Setenv("OLLAMA_MODELS", dir)
-			envconfig.LoadConfig()

 			got, err := GetBlobsPath(tc.digest)

server/routes_create_test.go

@@ -15,7 +15,6 @@ import (
 	"github.com/gin-gonic/gin"

 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
 )

@@ -89,7 +88,6 @@ func TestCreateFromBin(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -117,7 +115,6 @@ func TestCreateFromModel(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -160,7 +157,6 @@ func TestCreateRemovesLayers(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -209,7 +205,6 @@ func TestCreateUnsetsSystem(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -267,7 +262,6 @@ func TestCreateMergeParameters(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -372,7 +366,6 @@ func TestCreateReplacesMessages(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -450,7 +443,6 @@ func TestCreateTemplateSystem(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -534,7 +526,6 @@ func TestCreateLicenses(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{

@@ -582,7 +573,6 @@ func TestCreateDetectTemplate(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server
 	t.Run("matched", func(t *testing.T) {

server/routes_delete_test.go

@@ -10,7 +10,6 @@ import (
 	"github.com/gin-gonic/gin"

 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/types/model"
 )

@@ -19,7 +18,6 @@ func TestDelete(t *testing.T) {
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
-	envconfig.LoadConfig()

 	var s Server

server/routes_list_test.go

@@ -9,14 +9,12 @@ import (
 	"github.com/gin-gonic/gin"

 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 )

 func TestList(t *testing.T) {
 	gin.SetMode(gin.TestMode)

 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()

 	expectNames := []string{
 		"mistral:7b-instruct-q4_0",

server/routes_test.go

@@ -19,7 +19,6 @@ import (
 	"github.com/stretchr/testify/require"

 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"

@@ -347,7 +346,6 @@ func Test_Routes(t *testing.T) {
 	}

 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()

 	s := &Server{}
 	router := s.GenerateRoutes()

@@ -378,7 +376,6 @@ func Test_Routes(t *testing.T) {

 func TestCase(t *testing.T) {
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()

 	cases := []string{
 		"mistral",

@@ -458,7 +455,6 @@ func TestCase(t *testing.T) {

 func TestShow(t *testing.T) {
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
-	envconfig.LoadConfig()

 	var s Server