add chat and generate tests with mock runner
parent 64039df6d7
commit 4a565cbf94

6 changed files with 679 additions and 14 deletions
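These tests swap the real llama.cpp runner for an in-process stub: mockRunner embeds llm.LlamaServer, its Completion method records the incoming llm.CompletionRequest and replays a canned llm.CompletionResponse, and its Tokenize emits one token per whitespace-separated field. A condensed sketch of the wiring, abridged from the new server/routes_generate_test.go below (createRequest is the server package's existing test helper; the Scheduler's channel and GPU fields are elided here):

	// Stub runner with a canned response; no external process is started.
	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{Done: true, DoneReason: "stop"},
	}

	// Point the scheduler's server factory and load hook at the stub.
	s := Server{sched: &Scheduler{
		newServerFn: newMockServer(&mock),
		loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
			req.successCh <- &runnerRef{llama: &mock}
		},
		// remaining Scheduler fields elided; see the full file below
	}}

	// After a handler runs, assertions can inspect the prompt the server
	// actually rendered, e.g. cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! ").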
@@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{
 		"tokenizer.ggml.add_bos_token",
 		"tokenizer.ggml.add_eos_token",
 		"tokenizer.chat_template",
+		"bert.pooling_type",
 	},
 }
 
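The single addition above registers "bert.pooling_type" in ggufKVOrder, the fixed order in which GGUF metadata keys are written; the new "missing capabilities" subtests below depend on it when they build a minimal embedding-only model, roughly:

	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name: "bert",
		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
			"general.architecture": "bert",
			"bert.pooling_type":    uint32(0),
		}, []llm.Tensor{})),
		Stream: &stream,
	})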
@@ -3,7 +3,6 @@ package server
 import (
 	"bytes"
 	"context"
-	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -11,14 +10,6 @@ import (
 	"github.com/ollama/ollama/template"
 )
 
-func tokenize(_ context.Context, s string) (tokens []int, err error) {
-	for range strings.Fields(s) {
-		tokens = append(tokens, len(tokens))
-	}
-
-	return
-}
-
 func TestChatPrompt(t *testing.T) {
 	type expect struct {
 		prompt string
@@ -192,15 +183,11 @@ func TestChatPrompt(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
-			prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
+			prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
 			if err != nil {
 				t.Fatal(err)
 			}
 
-			if tt.prompt != prompt {
-				t.Errorf("expected %q, got %q", tt.prompt, prompt)
-			}
-
 			if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
 				t.Errorf("mismatch (-got +want):\n%s", diff)
 			}
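The replacement above passes the method value mockRunner{}.Tokenize where the deleted package-level tokenize helper used to go; since Tokenize has a value receiver, the method value can be used directly as the func(context.Context, string) ([]int, error) callback chatPrompt expects, for example:

	tokenize := mockRunner{}.Tokenize
	tokens, err := tokenize(context.TODO(), "a b c") // -> [0 1 2], nil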
@@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
 }
 
 func TestCreateFromBin(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) {
 }
 
 func TestCreateFromModel(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) {
 }
 
 func TestCreateRemovesLayers(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) {
 }
 
 func TestCreateUnsetsSystem(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
 }
 
 func TestCreateMergeParameters(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) {
 }
 
 func TestCreateReplacesMessages(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) {
 }
 
 func TestCreateTemplateSystem(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) {
 }
 
 func TestCreateLicenses(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) {
 }
 
 func TestCreateDetectTemplate(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -8,12 +8,15 @@ import (
 	"path/filepath"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/types/model"
 )
 
 func TestDelete(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	envconfig.LoadConfig()
@@ -77,6 +80,8 @@ func TestDelete(t *testing.T) {
 }
 
 func TestDeleteDuplicateLayers(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	p := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", p)
 	var s Server
server/routes_generate_test.go (new file, 651 additions)
@@ -0,0 +1,651 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/google/go-cmp/cmp"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/gpu"
+	"github.com/ollama/ollama/llm"
+)
+
+type mockRunner struct {
+	llm.LlamaServer
+
+	// CompletionRequest is only valid until the next call to Completion
+	llm.CompletionRequest
+	llm.CompletionResponse
+}
+
+func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+	m.CompletionRequest = r
+	fn(m.CompletionResponse)
+	return nil
+}
+
+func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
+	for range strings.Fields(s) {
+		tokens = append(tokens, len(tokens))
+	}
+
+	return
+}
+
+func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
+	return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
+		return mock, nil
+	}
+}
+
+func TestGenerateChat(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         "stop",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mock),
+			getGpuFn:      gpu.GetGPUInfo,
+			getCpuFn:      gpu.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+			},
+		},
+	}
+
+	go s.sched.Run(context.TODO())
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name: "test",
+		Modelfile: fmt.Sprintf(`FROM %s
+TEMPLATE """
+{{- if .System }}System: {{ .System }} {{ end }}
+{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
+{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
+`, createBinFile(t, llm.KV{
+			"general.architecture":          "llama",
+			"llama.block_count":             uint32(1),
+			"llama.context_length":          uint32(8192),
+			"llama.embedding_length":        uint32(4096),
+			"llama.attention.head_count":    uint32(32),
+			"llama.attention.head_count_kv": uint32(8),
+			"tokenizer.ggml.tokens":         []string{""},
+			"tokenizer.ggml.scores":         []float32{0},
+			"tokenizer.ggml.token_type":     []int32{0},
+		}, []llm.Tensor{
+			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		})),
+		Stream: &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("missing body", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, nil)
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing capabilities", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name: "bert",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
+				"general.architecture": "bert",
+				"bert.pooling_type":    uint32(0),
+			}, []llm.Tensor{})),
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		w = createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "bert",
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("load model", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var actual api.ChatResponse
+		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != "test" {
+			t.Errorf("expected model test, got %s", actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "load" {
+			t.Errorf("expected done reason load, got %s", actual.DoneReason)
+		}
+	})
+
+	checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
+		t.Helper()
+
+		var actual api.ChatResponse
+		if err := json.NewDecoder(body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != model {
+			t.Errorf("expected model %s, got %s", model, actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "stop" {
+			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
+		}
+
+		if diff := cmp.Diff(actual.Message, api.Message{
+			Role:    "assistant",
+			Content: content,
+		}); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		if actual.PromptEvalCount == 0 {
+			t.Errorf("expected prompt eval count > 0, got 0")
+		}
+
+		if actual.PromptEvalDuration == 0 {
+			t.Errorf("expected prompt eval duration > 0, got 0")
+		}
+
+		if actual.EvalCount == 0 {
+			t.Errorf("expected eval count > 0, got 0")
+		}
+
+		if actual.EvalDuration == 0 {
+			t.Errorf("expected eval duration > 0, got 0")
+		}
+
+		if actual.LoadDuration == 0 {
+			t.Errorf("expected load duration > 0, got 0")
+		}
+
+		if actual.TotalDuration == 0 {
+			t.Errorf("expected total duration > 0, got 0")
+		}
+	}
+
+	mock.CompletionResponse.Content = "Hi!"
+	t.Run("messages", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test", "Hi!")
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model:     "test-system",
+		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("messages with model system", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test-system", "Hi!")
+	})
+
+	mock.CompletionResponse.Content = "Abra kadabra!"
+	t.Run("messages with system", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "system", Content: "You can perform magic tricks."},
+				{Role: "user", Content: "Hello!"},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+
+	t.Run("messages with interleaved system", func(t *testing.T) {
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test-system",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+				{Role: "assistant", Content: "I can help you with that."},
+				{Role: "system", Content: "You can perform magic tricks."},
+				{Role: "user", Content: "Help me write tests."},
+			},
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+}
+
+func TestGenerate(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mock := mockRunner{
+		CompletionResponse: llm.CompletionResponse{
+			Done:               true,
+			DoneReason:         "stop",
+			PromptEvalCount:    1,
+			PromptEvalDuration: 1,
+			EvalCount:          1,
+			EvalDuration:       1,
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mock),
+			getGpuFn:      gpu.GetGPUInfo,
+			getCpuFn:      gpu.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+				req.successCh <- &runnerRef{
+					llama: &mock,
+				}
+			},
+		},
+	}
+
+	go s.sched.Run(context.TODO())
+
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name: "test",
+		Modelfile: fmt.Sprintf(`FROM %s
+TEMPLATE """
+{{- if .System }}System: {{ .System }} {{ end }}
+{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
+{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
+`, createBinFile(t, llm.KV{
+			"general.architecture":          "llama",
+			"llama.block_count":             uint32(1),
+			"llama.context_length":          uint32(8192),
+			"llama.embedding_length":        uint32(4096),
+			"llama.attention.head_count":    uint32(32),
+			"llama.attention.head_count_kv": uint32(8),
+			"tokenizer.ggml.tokens":         []string{""},
+			"tokenizer.ggml.scores":         []float32{0},
+			"tokenizer.ggml.token_type":     []int32{0},
+		}, []llm.Tensor{
+			{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+			{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+		})),
+		Stream: &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("missing body", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, nil)
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("missing capabilities", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name: "bert",
+			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
+				"general.architecture": "bert",
+				"bert.pooling_type":    uint32(0),
+			}, []llm.Tensor{})),
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status 200, got %d", w.Code)
+		}
+
+		w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model: "bert",
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("load model", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model: "test",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		var actual api.GenerateResponse
+		if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != "test" {
+			t.Errorf("expected model test, got %s", actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "load" {
+			t.Errorf("expected done reason load, got %s", actual.DoneReason)
+		}
+	})
+
+	checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
+		t.Helper()
+
+		var actual api.GenerateResponse
+		if err := json.NewDecoder(body).Decode(&actual); err != nil {
+			t.Fatal(err)
+		}
+
+		if actual.Model != model {
+			t.Errorf("expected model %s, got %s", model, actual.Model)
+		}
+
+		if !actual.Done {
+			t.Errorf("expected done true, got false")
+		}
+
+		if actual.DoneReason != "stop" {
+			t.Errorf("expected done reason stop, got %s", actual.DoneReason)
+		}
+
+		if actual.Response != content {
+			t.Errorf("expected response %s, got %s", content, actual.Response)
+		}
+
+		if actual.Context == nil {
+			t.Errorf("expected context not nil")
+		}
+
+		if actual.PromptEvalCount == 0 {
+			t.Errorf("expected prompt eval count > 0, got 0")
+		}
+
+		if actual.PromptEvalDuration == 0 {
+			t.Errorf("expected prompt eval duration > 0, got 0")
+		}
+
+		if actual.EvalCount == 0 {
+			t.Errorf("expected eval count > 0, got 0")
+		}
+
+		if actual.EvalDuration == 0 {
+			t.Errorf("expected eval duration > 0, got 0")
+		}
+
+		if actual.LoadDuration == 0 {
+			t.Errorf("expected load duration > 0, got 0")
+		}
+
+		if actual.TotalDuration == 0 {
+			t.Errorf("expected total duration > 0, got 0")
+		}
+	}
+
+	mock.CompletionResponse.Content = "Hi!"
+	t.Run("prompt", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test",
+			Prompt: "Hello!",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test", "Hi!")
+	})
+
+	w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Model:     "test-system",
+		Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("prompt with model system", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Hello!",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test-system", "Hi!")
+	})
+
+	mock.CompletionResponse.Content = "Abra kadabra!"
+	t.Run("prompt with system", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Hello!",
+			System: "You can perform magic tricks.",
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+
+	t.Run("prompt with template", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Help me write tests.",
+			System: "You can perform magic tricks.",
+			Template: `{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
+{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+
+		checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
+	})
+
+	t.Run("raw", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-system",
+			Prompt: "Help me write tests.",
+			Raw:    true,
+			Stream: &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+}
@@ -7,11 +7,14 @@ import (
 	"slices"
 	"testing"
 
+	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 )
 
 func TestList(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
 	t.Setenv("OLLAMA_MODELS", t.TempDir())
 	envconfig.LoadConfig()
 
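To run just the handler tests touched by this commit (standard go tooling, from the repository root; the TestGenerate pattern also matches TestGenerateChat):

	go test ./server -run 'TestGenerate|TestChatPrompt|TestCreate|TestDelete|TestList'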