Revamp Go-based integration tests
This uplevels the integration tests so they start the server themselves, and also allows testing against an existing or remote server.
parent a5ba0fcf78
commit 949b6c01e0
8 changed files with 313 additions and 261 deletions
integration/README.md (new file, 11 lines)
@@ -0,0 +1,11 @@
# Integration Tests

This directory contains integration tests that exercise Ollama end-to-end to verify its behavior.

By default, these tests are disabled so that `go test ./...` exercises only unit tests. To run the integration tests you must pass the integration tag: `go test -tags=integration ./...`

The integration tests have two modes of operation:

1. By default, they start the server on a random port, run the tests, and then shut the server down.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests run against an existing, already running server, which can be remote.
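For example, pointing the tests at an already running (possibly remote) server might look like `OLLAMA_TEST_EXISTING=1 OLLAMA_HOST=http://127.0.0.1:11434 go test -tags=integration ./integration/...`; the test helpers read `OLLAMA_HOST` to locate the server.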
integration/basic_test.go (new file, 28 lines)
@@ -0,0 +1,28 @@
//go:build integration

package integration

import (
	"context"
	"net/http"
	"testing"
	"time"

	"github.com/jmorganca/ollama/api"
)

func TestOrcaMiniBlueSky(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()
	// Set up the test data
	req := api.GenerateRequest{
		Model:  "orca-mini",
		Prompt: "why is the sky blue?",
		Stream: &stream,
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}
	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"})
}
@@ -1,49 +1,38 @@
 //go:build integration
 
-package server
+package integration
 
 import (
 	"context"
 	"encoding/base64"
-	"log"
-	"os"
-	"strings"
+	"net/http"
 	"testing"
 	"time"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/llm"
-	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
 func TestIntegrationMultimodal(t *testing.T) {
-	SkipIFNoTestData(t)
 	image, err := base64.StdEncoding.DecodeString(imageEncoding)
 	require.NoError(t, err)
 	req := api.GenerateRequest{
 		Model:  "llava:7b",
 		Prompt: "what does the text in this image say?",
-		Options: map[string]interface{}{},
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
 		Images: []api.ImageData{
 			image,
 		},
 	}
 
 	resp := "the ollamas"
-	workDir, err := os.MkdirTemp("", "ollama")
-	require.NoError(t, err)
-	defer os.RemoveAll(workDir)
-	require.NoError(t, llm.Init(workDir))
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
 	defer cancel()
-	opts := api.DefaultOptions()
-	opts.Seed = 42
-	opts.Temperature = 0.0
-	model, llmRunner := PrepareModelForPrompts(t, req.Model, opts)
-	defer llmRunner.Close()
-	response := OneShotPromptResponse(t, ctx, req, model, llmRunner)
-	log.Print(response)
-	assert.Contains(t, strings.ToLower(response), resp)
+	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
 }
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
integration/llm_test.go (new file, 73 lines)
@@ -0,0 +1,73 @@
//go:build integration

package integration

import (
	"context"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/jmorganca/ollama/api"
)

// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies

// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^

var (
	stream = false
	req    = [2]api.GenerateRequest{
		{
			Model:  "orca-mini",
			Prompt: "why is the ocean blue?",
			Stream: &stream,
			Options: map[string]interface{}{
				"seed":        42,
				"temperature": 0.0,
			},
		}, {
			Model:  "orca-mini",
			Prompt: "what is the origin of the us thanksgiving holiday?",
			Stream: &stream,
			Options: map[string]interface{}{
				"seed":        42,
				"temperature": 0.0,
			},
		},
	}
	resp = [2]string{
		"scattering",
		"united states thanksgiving",
	}
)

func TestIntegrationSimpleOrcaMini(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]})
}

// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	var wg sync.WaitGroup
	wg.Add(len(req))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]})
		}(i)
	}
	wg.Wait()
}

// TODO - create a parallel test with 2 different models once we support concurrency
integration/utils_test.go (new file, 190 lines)
@@ -0,0 +1,190 @@
//go:build integration

package integration

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/app/lifecycle"
	"github.com/stretchr/testify/assert"
)

func FindPort() string {
	port := 0
	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
		var l *net.TCPListener
		if l, err = net.ListenTCP("tcp", a); err == nil {
			port = l.Addr().(*net.TCPAddr).Port
			l.Close()
		}
	}
	if port == 0 {
		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
	}
	return strconv.Itoa(port)
}

func GetTestEndpoint() (string, string) {
	defaultPort := "11434"
	ollamaHost := os.Getenv("OLLAMA_HOST")

	scheme, hostport, ok := strings.Cut(ollamaHost, "://")
	if !ok {
		scheme, hostport = "http", ollamaHost
	}

	// trim trailing slashes
	hostport = strings.TrimRight(hostport, "/")

	host, port, err := net.SplitHostPort(hostport)
	if err != nil {
		host, port = "127.0.0.1", defaultPort
		if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
			host = ip.String()
		} else if hostport != "" {
			host = hostport
		}
	}

	if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
		port = FindPort()
	}

	url := fmt.Sprintf("%s:%s", host, port)
	slog.Info("server connection", "url", url)
	return scheme, url
}

// TODO make fanicier, grab logs, etc.
var serverMutex sync.Mutex
var serverReady bool

func StartServer(ctx context.Context, ollamaHost string) error {
	// Make sure the server has been built
	CLIName, err := filepath.Abs("../ollama")
	if err != nil {
		return err
	}

	if runtime.GOOS == "windows" {
		CLIName += ".exe"
	}
	_, err = os.Stat(CLIName)
	if err != nil {
		return fmt.Errorf("CLI missing, did you forget to build first? %w", err)
	}
	serverMutex.Lock()
	defer serverMutex.Unlock()
	if serverReady {
		return nil
	}

	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
		os.Setenv("OLLAMA_HOST", ollamaHost)
	}

	slog.Info("starting server", "url", ollamaHost)
	done, err := lifecycle.SpawnServer(ctx, "../ollama")
	if err != nil {
		return fmt.Errorf("failed to start server: %w", err)
	}

	go func() {
		<-ctx.Done()
		serverMutex.Lock()
		defer serverMutex.Unlock()
		exitCode := <-done
		if exitCode > 0 {
			slog.Warn("server failure", "exit", exitCode)
		}
		serverReady = false
	}()

	// TODO wait only long enough for the server to be responsive...
	time.Sleep(500 * time.Millisecond)

	serverReady = true
	return nil
}

func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
	requestJSON, err := json.Marshal(genReq)
	if err != nil {
		t.Fatalf("Error serializing request: %v", err)
	}
	defer func() {
		if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
			// TODO
			fp, err := os.Open(lifecycle.ServerLogFile)
			if err != nil {
				slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
				return
			}
			data, err := io.ReadAll(fp)
			if err != nil {
				slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
				return
			}
			slog.Warn("SERVER LOG FOLLOWS")
			os.Stderr.Write(data)
			slog.Warn("END OF SERVER")
		}
		err = os.Remove(lifecycle.ServerLogFile)
		if err != nil && !os.IsNotExist(err) {
			slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
		}
	}()
	scheme, testEndpoint := GetTestEndpoint()

	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
		assert.NoError(t, StartServer(ctx, testEndpoint))
	}

	// Make the request and get the response
	req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
	if err != nil {
		t.Fatalf("Error creating request: %v", err)
	}

	// Set the content type for the request
	req.Header.Set("Content-Type", "application/json")

	// Make the request with the HTTP client
	response, err := client.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatalf("Error making request: %v", err)
	}
	body, err := io.ReadAll(response.Body)
	assert.NoError(t, err)
	assert.Equal(t, response.StatusCode, 200, string(body))

	// Verify the response is valid JSON
	var payload api.GenerateResponse
	err = json.Unmarshal(body, &payload)
	if err != nil {
		assert.NoError(t, err, body)
	}

	// Verify the response contains the expected data
	for _, resp := range anyResp {
		assert.Contains(t, strings.ToLower(payload.Response), resp)
	}
}
@@ -1,41 +0,0 @@
#!/bin/bash

# This script sets up integration tests which run the full stack to verify
# inference locally
#
# To run the relevant tests use
# go test -tags=integration ./server
set -e
set -o pipefail

REPO=$(dirname $0)/../
export OLLAMA_MODELS=${REPO}/test_data/models
REGISTRY_SCHEME=https
REGISTRY=registry.ollama.ai
TEST_MODELS=("library/orca-mini:latest" "library/llava:7b")
ACCEPT_HEADER="Accept: application/vnd.docker.distribution.manifest.v2+json"

for model in ${TEST_MODELS[@]}; do
    TEST_MODEL=$(echo ${model} | cut -f1 -d:)
    TEST_MODEL_TAG=$(echo ${model} | cut -f2 -d:)
    mkdir -p ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/
    mkdir -p ${OLLAMA_MODELS}/blobs/

    echo "Pulling manifest for ${TEST_MODEL}:${TEST_MODEL_TAG}"
    curl -s --header "${ACCEPT_HEADER}" \
        -o ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} \
        ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/manifests/${TEST_MODEL_TAG}

    CFG_HASH=$(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".config.digest")
    echo "Pulling config blob ${CFG_HASH}"
    curl -L -C - --header "${ACCEPT_HEADER}" \
        -o ${OLLAMA_MODELS}/blobs/${CFG_HASH} \
        ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${CFG_HASH}

    for LAYER in $(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".layers[].digest"); do
        echo "Pulling blob ${LAYER}"
        curl -L -C - --header "${ACCEPT_HEADER}" \
            -o ${OLLAMA_MODELS}/blobs/${LAYER} \
            ${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${LAYER}
    done
done
@@ -1,123 +0,0 @@
//go:build integration

package server

import (
	"context"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
)

// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies

// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^

var (
	req = [2]api.GenerateRequest{
		{
			Model:   "orca-mini",
			Prompt:  "tell me a short story about agi?",
			Options: map[string]interface{}{},
		}, {
			Model:   "orca-mini",
			Prompt:  "what is the origin of the us thanksgiving holiday?",
			Options: map[string]interface{}{},
		},
	}
	resp = [2]string{
		"once upon a time",
		"united states thanksgiving",
	}
)

func TestIntegrationSimpleOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
	assert.Contains(t, strings.ToLower(response), resp[0])
}

// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)

	t.Skip("concurrent prediction on single runner not currently supported")

	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}

func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))

	t.Logf("Running %d concurrently", len(req))
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}

// TODO - create a parallel test with 2 different models once we support concurrency
@@ -1,75 +0,0 @@
//go:build integration

package server

import (
	"context"
	"errors"
	"os"
	"path"
	"runtime"
	"testing"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
	"github.com/stretchr/testify/require"
)

func SkipIFNoTestData(t *testing.T) {
	modelDir := getModelDir()
	if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) {
		t.Skipf("%s does not exist - skipping integration tests", modelDir)
	}
}

func getModelDir() string {
	_, filename, _, _ := runtime.Caller(0)
	return path.Dir(path.Dir(filename) + "/../test_data/models/.")
}

func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) {
	modelDir := getModelDir()
	os.Setenv("OLLAMA_MODELS", modelDir)
	model, err := GetModel(modelName)
	require.NoError(t, err, "GetModel ")
	err = opts.FromMap(model.Options)
	require.NoError(t, err, "opts from model ")
	runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
	require.NoError(t, err, "llm.New failed")
	return model, runner
}

func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
	prompt, err := model.PreResponsePrompt(PromptVars{
		System: req.System,
		Prompt: req.Prompt,
		First:  len(req.Context) == 0,
	})
	require.NoError(t, err, "prompt generation failed")
	success := make(chan bool, 1)
	response := ""
	cb := func(r llm.PredictResult) {
		if !r.Done {
			response += r.Content
		} else {
			success <- true
		}
	}

	predictReq := llm.PredictOpts{
		Prompt: prompt,
		Format: req.Format,
		Images: req.Images,
	}
	err = runner.Predict(ctx, predictReq, cb)
	require.NoError(t, err, "predict call failed")

	select {
	case <-ctx.Done():
		t.Errorf("failed to complete before timeout: \n%s", response)
		return ""
	case <-success:
		return response
	}
}