Merge pull request #3309 from dhiltgen/integration_testing
Revamp go based integration tests
This commit is contained in:
commit
1784113ef5
8 changed files with 313 additions and 261 deletions
11
integration/README.md
Normal file
11
integration/README.md
Normal file
|
@ -0,0 +1,11 @@
|
|||
# Integration Tests
|
||||
|
||||
This directory contains integration tests that exercise Ollama end-to-end to verify its behavior.
|
||||
|
||||
By default, these tests are disabled, so `go test ./...` exercises only the unit tests. To run the integration tests, pass the `integration` build tag: `go test -tags=integration ./...`
|
||||
|
||||
|
||||
The integration tests have two modes of operation.
|
||||
|
||||
1. By default, they will start the server on a random port, run the tests, and then shut down the server.
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which may be remote.
|
28
integration/basic_test.go
Normal file
28
integration/basic_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the sky blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"})
|
||||
}
|
|
@ -1,49 +1,38 @@
|
|||
//go:build integration
|
||||
|
||||
package server
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIntegrationMultimodal(t *testing.T) {
|
||||
SkipIFNoTestData(t)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
Model: "llava:7b",
|
||||
Prompt: "what does the text in this image say?",
|
||||
Options: map[string]interface{}{},
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
|
||||
resp := "the ollamas"
|
||||
workDir, err := os.MkdirTemp("", "ollama")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(workDir)
|
||||
require.NoError(t, llm.Init(workDir))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
|
||||
defer cancel()
|
||||
opts := api.DefaultOptions()
|
||||
opts.Seed = 42
|
||||
opts.Temperature = 0.0
|
||||
model, llmRunner := PrepareModelForPrompts(t, req.Model, opts)
|
||||
defer llmRunner.Close()
|
||||
response := OneShotPromptResponse(t, ctx, req, model, llmRunner)
|
||||
log.Print(response)
|
||||
assert.Contains(t, strings.ToLower(response), resp)
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
|
||||
}
|
||||
|
||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
73
integration/llm_test.go
Normal file
73
integration/llm_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
|
||||
//
|
||||
// TODO - Fix this ^^
|
||||
|
||||
var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2]string{
|
||||
"scattering",
|
||||
"united states thanksgiving",
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]})
|
||||
}
|
||||
|
||||
// TODO
|
||||
// The server always loads a new runner and closes the old one, which forces serial execution
|
||||
// At present this test case fails with concurrency problems. Eventually we should try to
|
||||
// get true concurrency working with n_parallel support in the backend
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
|
||||
defer cancel()
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]})
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TODO - create a parallel test with 2 different models once we support concurrency
|
190
integration/utils_test.go
Normal file
190
integration/utils_test.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/app/lifecycle"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// FindPort returns a TCP port number as a decimal string. It first asks the
// OS for a free port by briefly listening on localhost:0; if that does not
// yield a port it falls back to a random one in the ephemeral range.
func FindPort() string {
	chosen := 0

	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	if err == nil {
		listener, listenErr := net.ListenTCP("tcp", addr)
		if listenErr == nil {
			// Record the port the OS picked, then release it immediately.
			chosen = listener.Addr().(*net.TCPAddr).Port
			listener.Close()
		}
	}

	if chosen == 0 {
		// get a random port in the ephemeral range
		const lo, hi = 49152, 65535
		chosen = lo + rand.Intn(hi-lo)
	}
	return strconv.Itoa(chosen)
}
|
||||
|
||||
func GetTestEndpoint() (string, string) {
|
||||
defaultPort := "11434"
|
||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||
|
||||
scheme, hostport, ok := strings.Cut(ollamaHost, "://")
|
||||
if !ok {
|
||||
scheme, hostport = "http", ollamaHost
|
||||
}
|
||||
|
||||
// trim trailing slashes
|
||||
hostport = strings.TrimRight(hostport, "/")
|
||||
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
host, port = "127.0.0.1", defaultPort
|
||||
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
|
||||
host = ip.String()
|
||||
} else if hostport != "" {
|
||||
host = hostport
|
||||
}
|
||||
}
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
|
||||
port = FindPort()
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s:%s", host, port)
|
||||
slog.Info("server connection", "url", url)
|
||||
return scheme, url
|
||||
}
|
||||
|
||||
// TODO make fancier, grab logs, etc.
|
||||
var serverMutex sync.Mutex
|
||||
var serverReady bool
|
||||
|
||||
func StartServer(ctx context.Context, ollamaHost string) error {
|
||||
// Make sure the server has been built
|
||||
CLIName, err := filepath.Abs("../ollama")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
CLIName += ".exe"
|
||||
}
|
||||
_, err = os.Stat(CLIName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CLI missing, did you forget to build first? %w", err)
|
||||
}
|
||||
serverMutex.Lock()
|
||||
defer serverMutex.Unlock()
|
||||
if serverReady {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
|
||||
slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
|
||||
os.Setenv("OLLAMA_HOST", ollamaHost)
|
||||
}
|
||||
|
||||
slog.Info("starting server", "url", ollamaHost)
|
||||
done, err := lifecycle.SpawnServer(ctx, "../ollama")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
serverMutex.Lock()
|
||||
defer serverMutex.Unlock()
|
||||
exitCode := <-done
|
||||
if exitCode > 0 {
|
||||
slog.Warn("server failure", "exit", exitCode)
|
||||
}
|
||||
serverReady = false
|
||||
}()
|
||||
|
||||
// TODO wait only long enough for the server to be responsive...
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
serverReady = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
|
||||
requestJSON, err := json.Marshal(genReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error serializing request: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
// TODO
|
||||
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||
if err != nil {
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("SERVER LOG FOLLOWS")
|
||||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
err = os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
}()
|
||||
scheme, testEndpoint := GetTestEndpoint()
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
assert.NoError(t, StartServer(ctx, testEndpoint))
|
||||
}
|
||||
|
||||
// Make the request and get the response
|
||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating request: %v", err)
|
||||
}
|
||||
|
||||
// Set the content type for the request
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Make the request with the HTTP client
|
||||
response, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
t.Fatalf("Error making request: %v", err)
|
||||
}
|
||||
body, err := io.ReadAll(response.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, response.StatusCode, 200, string(body))
|
||||
|
||||
// Verify the response is valid JSON
|
||||
var payload api.GenerateResponse
|
||||
err = json.Unmarshal(body, &payload)
|
||||
if err != nil {
|
||||
assert.NoError(t, err, body)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
for _, resp := range anyResp {
|
||||
assert.Contains(t, strings.ToLower(payload.Response), resp)
|
||||
}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
#!/bin/bash

# This script sets up integration tests which run the full stack to verify
# inference locally
#
# For each entry in TEST_MODELS it downloads the manifest, the config blob and
# every layer blob from the registry into ${REPO}/test_data/models.
#
# To run the relevant tests use
# go test -tags=integration ./server
set -e
set -o pipefail

REPO="$(dirname "$0")/../"
export OLLAMA_MODELS="${REPO}/test_data/models"
REGISTRY_SCHEME=https
REGISTRY=registry.ollama.ai
TEST_MODELS=("library/orca-mini:latest" "library/llava:7b")
ACCEPT_HEADER="Accept: application/vnd.docker.distribution.manifest.v2+json"

for model in "${TEST_MODELS[@]}"; do
    TEST_MODEL=$(echo "${model}" | cut -f1 -d:)
    TEST_MODEL_TAG=$(echo "${model}" | cut -f2 -d:)
    MANIFEST="${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG}"
    mkdir -p "${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/"
    mkdir -p "${OLLAMA_MODELS}/blobs/"

    echo "Pulling manifest for ${TEST_MODEL}:${TEST_MODEL_TAG}"
    # --fail makes an HTTP error abort the script (set -e) instead of the
    # error body being saved as the manifest and parsed by jq below.
    curl -fsS --header "${ACCEPT_HEADER}" \
        -o "${MANIFEST}" \
        "${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/manifests/${TEST_MODEL_TAG}"

    CFG_HASH=$(jq -r ".config.digest" "${MANIFEST}")
    echo "Pulling config blob ${CFG_HASH}"
    # -C - resumes a partially downloaded blob
    curl -L -C - --header "${ACCEPT_HEADER}" \
        -o "${OLLAMA_MODELS}/blobs/${CFG_HASH}" \
        "${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${CFG_HASH}"

    # Digests contain no whitespace, so word-splitting jq's output is safe.
    for LAYER in $(jq -r ".layers[].digest" "${MANIFEST}"); do
        echo "Pulling blob ${LAYER}"
        curl -L -C - --header "${ACCEPT_HEADER}" \
            -o "${OLLAMA_MODELS}/blobs/${LAYER}" \
            "${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${LAYER}"
    done
done
|
|
@ -1,123 +0,0 @@
|
|||
//go:build integration
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
)
|
||||
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
|
||||
//
|
||||
// TODO - Fix this ^^
|
||||
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "tell me a short story about agi?",
|
||||
Options: map[string]interface{}{},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Options: map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
resp = [2]string{
|
||||
"once upon a time",
|
||||
"united states thanksgiving",
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||
SkipIFNoTestData(t)
|
||||
workDir, err := os.MkdirTemp("", "ollama")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(workDir)
|
||||
require.NoError(t, llm.Init(workDir))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
|
||||
defer cancel()
|
||||
opts := api.DefaultOptions()
|
||||
opts.Seed = 42
|
||||
opts.Temperature = 0.0
|
||||
model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
|
||||
defer llmRunner.Close()
|
||||
response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
|
||||
assert.Contains(t, strings.ToLower(response), resp[0])
|
||||
}
|
||||
|
||||
// TODO
|
||||
// The server always loads a new runner and closes the old one, which forces serial execution
|
||||
// At present this test case fails with concurrency problems. Eventually we should try to
|
||||
// get true concurrency working with n_parallel support in the backend
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
SkipIFNoTestData(t)
|
||||
|
||||
t.Skip("concurrent prediction on single runner not currently supported")
|
||||
|
||||
workDir, err := os.MkdirTemp("", "ollama")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(workDir)
|
||||
require.NoError(t, llm.Init(workDir))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
|
||||
defer cancel()
|
||||
opts := api.DefaultOptions()
|
||||
opts.Seed = 42
|
||||
opts.Temperature = 0.0
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
|
||||
defer llmRunner.Close()
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
|
||||
t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
|
||||
assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
|
||||
SkipIFNoTestData(t)
|
||||
workDir, err := os.MkdirTemp("", "ollama")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(workDir)
|
||||
require.NoError(t, llm.Init(workDir))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
|
||||
defer cancel()
|
||||
opts := api.DefaultOptions()
|
||||
opts.Seed = 42
|
||||
opts.Temperature = 0.0
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
|
||||
t.Logf("Running %d concurrently", len(req))
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
|
||||
defer llmRunner.Close()
|
||||
response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
|
||||
t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
|
||||
assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TODO - create a parallel test with 2 different models once we support concurrency
|
|
@ -1,75 +0,0 @@
|
|||
//go:build integration
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func SkipIFNoTestData(t *testing.T) {
|
||||
modelDir := getModelDir()
|
||||
if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) {
|
||||
t.Skipf("%s does not exist - skipping integration tests", modelDir)
|
||||
}
|
||||
}
|
||||
|
||||
// getModelDir returns the test_data/models directory that sits one level
// above the directory containing this source file.
func getModelDir() string {
	_, thisFile, _, _ := runtime.Caller(0)
	// The trailing "/." is dropped by the outer path.Dir, which also cleans
	// the "..", leaving the models directory itself.
	marker := path.Dir(thisFile) + "/../test_data/models/."
	return path.Dir(marker)
}
|
||||
|
||||
func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) {
|
||||
modelDir := getModelDir()
|
||||
os.Setenv("OLLAMA_MODELS", modelDir)
|
||||
model, err := GetModel(modelName)
|
||||
require.NoError(t, err, "GetModel ")
|
||||
err = opts.FromMap(model.Options)
|
||||
require.NoError(t, err, "opts from model ")
|
||||
runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
||||
require.NoError(t, err, "llm.New failed")
|
||||
return model, runner
|
||||
}
|
||||
|
||||
func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
|
||||
prompt, err := model.PreResponsePrompt(PromptVars{
|
||||
System: req.System,
|
||||
Prompt: req.Prompt,
|
||||
First: len(req.Context) == 0,
|
||||
})
|
||||
require.NoError(t, err, "prompt generation failed")
|
||||
success := make(chan bool, 1)
|
||||
response := ""
|
||||
cb := func(r llm.PredictResult) {
|
||||
|
||||
if !r.Done {
|
||||
response += r.Content
|
||||
} else {
|
||||
success <- true
|
||||
}
|
||||
}
|
||||
|
||||
predictReq := llm.PredictOpts{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: req.Images,
|
||||
}
|
||||
err = runner.Predict(ctx, predictReq, cb)
|
||||
require.NoError(t, err, "predict call failed")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Errorf("failed to complete before timeout: \n%s", response)
|
||||
return ""
|
||||
case <-success:
|
||||
return response
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue