From f4408219e92a7b22107a68d0b3f5eb545c06aed9 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 5 Jul 2024 15:30:06 -0700 Subject: [PATCH 01/29] Refine scheduler unit tests for reliability This breaks up some of the test scenarios to create a more reliable set of tests, as well as adding a little more coverage. --- server/sched_test.go | 327 ++++++++++++++++++++++++++----------------- 1 file changed, 196 insertions(+), 131 deletions(-) diff --git a/server/sched_test.go b/server/sched_test.go index 3fbd188a..c16b407d 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "os" + "runtime" "testing" "time" @@ -94,7 +95,7 @@ func TestLoad(t *testing.T) { require.Len(t, s.expiredCh, 1) } -type bundle struct { +type reqBundle struct { ctx context.Context //nolint:containedctx ctxDone func() srv *mockLlm @@ -102,13 +103,13 @@ type bundle struct { ggml *llm.GGML } -func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { +func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { return scenario.srv, nil } -func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle { - scenario := &bundle{} - scenario.ctx, scenario.ctxDone = context.WithCancel(ctx) +func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle { + b := &reqBundle{} + b.ctx, b.ctxDone = context.WithCancel(ctx) t.Helper() f, err := os.CreateTemp(t.TempDir(), modelName) @@ -135,124 +136,154 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV fname := f.Name() model := &Model{Name: modelName, ModelPath: fname} - scenario.ggml, err = llm.LoadModel(model.ModelPath, 0) + b.ggml, err = llm.LoadModel(model.ModelPath, 0) require.NoError(t, err) - scenario.req = &LlmRequest{ - ctx: scenario.ctx, + if duration == nil { + duration = &api.Duration{Duration: 5 * time.Millisecond} + } + b.req = &LlmRequest{ + ctx: b.ctx, model: model, opts: api.DefaultOptions(), - sessionDuration: &api.Duration{Duration: 5 * time.Millisecond}, + sessionDuration: duration, successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), } - scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} - return scenario + b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} + return b } -func TestRequests(t *testing.T) { - ctx, done := context.WithTimeout(context.Background(), 10*time.Second) +func getGpuFn() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} +} + +func getCpuFn() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "cpu"} + g.TotalMemory = 32 * format.GigaByte + g.FreeMemory = 26 * format.GigaByte + return []gpu.GpuInfo{g} +} + +func TestRequestsSameModelSameRequest(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) defer done() - - // Same model, same request - scenario1a := newScenario(t, ctx, "ollama-model-1", 10) - scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond} - scenario1b := newScenario(t, ctx, "ollama-model-1", 11) - scenario1b.req.model = scenario1a.req.model - scenario1b.ggml = scenario1a.ggml - scenario1b.req.sessionDuration = &api.Duration{Duration: 0} - - // simple reload of same model - scenario2a := newScenario(t, ctx, "ollama-model-1", 20) - tmpModel := *scenario1a.req.model - scenario2a.req.model = &tmpModel - scenario2a.ggml = scenario1a.ggml - scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond} - - // Multiple loaded models - scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) - scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte) - scenario3c := newScenario(t, ctx, "ollama-model-4a", 30) - scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed - scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded - s := InitScheduler(ctx) - s.getGpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "metal"} - g.TotalMemory = 24 * format.GigaByte - g.FreeMemory = 12 * format.GigaByte - return []gpu.GpuInfo{g} - } - s.getCpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "cpu"} - g.TotalMemory = 32 * format.GigaByte - g.FreeMemory = 26 * format.GigaByte - return []gpu.GpuInfo{g} - } - s.newServerFn = scenario1a.newServer - slog.Info("scenario1a") - s.pendingReqCh <- scenario1a.req + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) + b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}) + b.req.model = a.req.model + b.ggml = a.ggml + + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req require.Len(t, s.pendingReqCh, 1) s.Run(ctx) select { - case resp := <-scenario1a.req.successCh: - require.Equal(t, resp.llama, scenario1a.srv) + case resp := <-a.req.successCh: + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario1a.req.errCh) - case err := <-scenario1a.req.errCh: + require.Empty(t, a.req.errCh) + case err := <-a.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } // Same runner as first request due to not needing a reload - s.newServerFn = scenario1b.newServer - slog.Info("scenario1b") - s.pendingReqCh <- scenario1b.req + s.newServerFn = b.newServer + slog.Info("b") + s.pendingReqCh <- b.req select { - case resp := <-scenario1b.req.successCh: - require.Equal(t, resp.llama, scenario1a.srv) + case resp := <-b.req.successCh: + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario1b.req.errCh) - case err := <-scenario1b.req.errCh: + require.Empty(t, b.req.errCh) + case err := <-b.req.errCh: + t.Fatal(err.Error()) + case <-ctx.Done(): + t.Fatal("timeout") + } +} + +func TestRequestsSimpleReloadSameModel(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + s := InitScheduler(ctx) + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) + b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}) + tmpModel := *a.req.model + b.req.model = &tmpModel + b.ggml = a.ggml + + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case resp := <-a.req.successCh: + require.Equal(t, resp.llama, a.srv) + require.Empty(t, s.pendingReqCh) + require.Empty(t, a.req.errCh) + case err := <-a.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } // Trigger a reload - s.newServerFn = scenario2a.newServer - scenario2a.req.model.AdapterPaths = []string{"new"} - slog.Info("scenario2a") - s.pendingReqCh <- scenario2a.req + s.newServerFn = b.newServer + b.req.model.AdapterPaths = []string{"new"} + slog.Info("b") + s.pendingReqCh <- b.req // finish first two requests, so model can reload time.Sleep(1 * time.Millisecond) - scenario1a.ctxDone() - scenario1b.ctxDone() + a.ctxDone() select { - case resp := <-scenario2a.req.successCh: - require.Equal(t, resp.llama, scenario2a.srv) + case resp := <-b.req.successCh: + require.Equal(t, resp.llama, b.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario2a.req.errCh) - case err := <-scenario2a.req.errCh: + require.Empty(t, b.req.errCh) + case err := <-b.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } +} + +func TestRequestsMultipleLoadedModels(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + s := InitScheduler(ctx) + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + + // Multiple loaded models + a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil) + b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil) + c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil) + c.req.opts.NumGPU = 0 // CPU load, will be allowed + d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded envconfig.MaxRunners = 1 - s.newServerFn = scenario3a.newServer - slog.Info("scenario3a") - s.pendingReqCh <- scenario3a.req - // finish prior request, so new model can load - time.Sleep(1 * time.Millisecond) - scenario2a.ctxDone() + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req + s.Run(ctx) select { - case resp := <-scenario3a.req.successCh: - require.Equal(t, resp.llama, scenario3a.srv) + case resp := <-a.req.successCh: + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3a.req.errCh) - case err := <-scenario3a.req.errCh: + require.Empty(t, a.req.errCh) + case err := <-a.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") @@ -262,15 +293,15 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() envconfig.MaxRunners = 0 - s.newServerFn = scenario3b.newServer - slog.Info("scenario3b") - s.pendingReqCh <- scenario3b.req + s.newServerFn = b.newServer + slog.Info("b") + s.pendingReqCh <- b.req select { - case resp := <-scenario3b.req.successCh: - require.Equal(t, resp.llama, scenario3b.srv) + case resp := <-b.req.successCh: + require.Equal(t, resp.llama, b.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3b.req.errCh) - case err := <-scenario3b.req.errCh: + require.Empty(t, b.req.errCh) + case err := <-b.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") @@ -280,15 +311,15 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() // This is a CPU load with NumGPU = 0 so it should load - s.newServerFn = scenario3c.newServer - slog.Info("scenario3c") - s.pendingReqCh <- scenario3c.req + s.newServerFn = c.newServer + slog.Info("c") + s.pendingReqCh <- c.req select { - case resp := <-scenario3c.req.successCh: - require.Equal(t, resp.llama, scenario3c.srv) + case resp := <-c.req.successCh: + require.Equal(t, resp.llama, c.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3c.req.errCh) - case err := <-scenario3c.req.errCh: + require.Empty(t, c.req.errCh) + case err := <-c.req.errCh: t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") @@ -298,25 +329,25 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() // Try to load a model that wont fit - s.newServerFn = scenario3d.newServer - slog.Info("scenario3d") + s.newServerFn = d.newServer + slog.Info("d") s.loadedMu.Lock() require.Len(t, s.loaded, 3) s.loadedMu.Unlock() - scenario3a.ctxDone() // Won't help since this one isn't big enough to make room + a.ctxDone() // Won't help since this one isn't big enough to make room time.Sleep(2 * time.Millisecond) - s.pendingReqCh <- scenario3d.req + s.pendingReqCh <- d.req // finish prior request, so new model can load time.Sleep(6 * time.Millisecond) s.loadedMu.Lock() require.Len(t, s.loaded, 2) s.loadedMu.Unlock() - scenario3b.ctxDone() + b.ctxDone() select { - case resp := <-scenario3d.req.successCh: - require.Equal(t, resp.llama, scenario3d.srv) + case resp := <-d.req.successCh: + require.Equal(t, resp.llama, d.srv) require.Empty(t, s.pendingReqCh) - require.Empty(t, scenario3d.req.errCh) + require.Empty(t, d.req.errCh) case <-ctx.Done(): t.Fatal("timeout") } @@ -325,30 +356,59 @@ func TestRequests(t *testing.T) { s.loadedMu.Unlock() } +func TestRequestsModelTooBigForSystem(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 4 * format.MebiByte + g.FreeMemory = 3 * format.MebiByte + return []gpu.GpuInfo{g} + } + + s.getCpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "cpu"} + g.TotalMemory = 4 * format.MebiByte + g.FreeMemory = 2 * format.MebiByte + return []gpu.GpuInfo{g} + } + a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) + + s.newServerFn = a.newServer + slog.Info("a") + s.pendingReqCh <- a.req + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case <-a.req.successCh: + if runtime.GOOS == "linux" { + t.Fatal("request should have been rejected with out of space") + } + // else - Darwin and Windows don't reject right now + case err := <-a.req.errCh: + require.Contains(t, err.Error(), "too large") + case <-ctx.Done(): + t.Fatal("timeout") + } +} func TestGetRunner(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) defer done() - scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) - scenario1a.req.sessionDuration = &api.Duration{Duration: 0} - scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) - scenario1b.req.sessionDuration = &api.Duration{Duration: 0} - scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) - scenario1c.req.sessionDuration = &api.Duration{Duration: 0} + a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}) + b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}) + c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}) envconfig.MaxQueuedRequests = 1 s := InitScheduler(ctx) - s.getGpuFn = func() gpu.GpuInfoList { - g := gpu.GpuInfo{Library: "metal"} - g.TotalMemory = 24 * format.GigaByte - g.FreeMemory = 12 * format.GigaByte - return []gpu.GpuInfo{g} - } - s.newServerFn = scenario1a.newServer - slog.Info("scenario1a") - successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) + s.getGpuFn = getGpuFn + s.getCpuFn = getCpuFn + s.newServerFn = a.newServer + slog.Info("a") + successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) - slog.Info("scenario1b") - successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) + slog.Info("b") + successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration) require.Len(t, s.pendingReqCh, 1) require.Empty(t, successCh1b) require.Len(t, errCh1b, 1) @@ -357,22 +417,24 @@ func TestGetRunner(t *testing.T) { s.Run(ctx) select { case resp := <-successCh1a: - require.Equal(t, resp.llama, scenario1a.srv) + require.Equal(t, resp.llama, a.srv) require.Empty(t, s.pendingReqCh) require.Empty(t, errCh1a) + case err := <-errCh1a: + t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } - scenario1a.ctxDone() + a.ctxDone() // Set "a" model to idle so it can unload s.loadedMu.Lock() require.Len(t, s.loaded, 1) s.loadedMu.Unlock() - scenario1c.req.model.ModelPath = "bad path" - slog.Info("scenario1c") - successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) + c.req.model.ModelPath = "bad path" + slog.Info("c") + successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration) // Starts in pending channel, then should be quickly processsed to return an error - time.Sleep(5 * time.Millisecond) + time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload require.Empty(t, successCh1c) s.loadedMu.Lock() require.Empty(t, s.loaded) @@ -380,7 +442,7 @@ func TestGetRunner(t *testing.T) { require.Len(t, errCh1c, 1) err = <-errCh1c require.Contains(t, err.Error(), "bad path") - scenario1b.ctxDone() + b.ctxDone() } // TODO - add one scenario that triggers the bogus finished event with positive ref count @@ -389,7 +451,7 @@ func TestPrematureExpired(t *testing.T) { defer done() // Same model, same request - scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil) s := InitScheduler(ctx) s.getGpuFn = func() gpu.GpuInfoList { g := gpu.GpuInfo{Library: "metal"} @@ -411,6 +473,8 @@ func TestPrematureExpired(t *testing.T) { s.loadedMu.Unlock() slog.Info("sending premature expired event now") s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe + case err := <-errCh1a: + t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } @@ -446,6 +510,8 @@ func TestUseLoadedRunner(t *testing.T) { select { case success := <-req.successCh: require.Equal(t, r1, success) + case err := <-req.errCh: + t.Fatal(err.Error()) case <-ctx.Done(): t.Fatal("timeout") } @@ -625,8 +691,7 @@ func TestAlreadyCanceled(t *testing.T) { defer done() dctx, done2 := context.WithCancel(ctx) done2() - scenario1a := newScenario(t, dctx, "ollama-model-1", 10) - scenario1a.req.sessionDuration = &api.Duration{Duration: 0} + scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}) s := InitScheduler(ctx) slog.Info("scenario1a") s.pendingReqCh <- scenario1a.req From 73e2c8f68fe075ea159a20bbf778c0cf801316ad Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Tue, 9 Jul 2024 15:28:25 -0700 Subject: [PATCH 02/29] Fix context exhaustion integration test for small gpus On the smaller GPUs, the initial model load of llama2 took over 30s (the default timeout for the DoGenerate helper) --- integration/context_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/integration/context_test.go b/integration/context_test.go index 46fac5ea..f1342e16 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -12,7 +12,7 @@ import ( func TestContextExhaustion(t *testing.T) { // Longer needed for small footprint GPUs - ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up the test data req := api.GenerateRequest{ @@ -25,5 +25,10 @@ func TestContextExhaustion(t *testing.T) { "num_ctx": 128, }, } - GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"}) + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + if err := PullIfMissing(ctx, client, req.Model); err != nil { + t.Fatalf("PullIfMissing failed: %v", err) + } + DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second) } From 4a565cbf9417ab6ec4560d334d556b5e97e23be9 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Sat, 13 Jul 2024 17:46:24 -0700 Subject: [PATCH 03/29] add chat and generate tests with mock runner --- llm/gguf.go | 1 + server/prompt_test.go | 15 +- server/routes_create_test.go | 18 + server/routes_delete_test.go | 5 + server/routes_generate_test.go | 651 +++++++++++++++++++++++++++++++++ server/routes_list_test.go | 3 + 6 files changed, 679 insertions(+), 14 deletions(-) create mode 100644 server/routes_generate_test.go diff --git a/llm/gguf.go b/llm/gguf.go index 4d343a1b..a8427aed 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -537,6 +537,7 @@ var ggufKVOrder = map[string][]string{ "tokenizer.ggml.add_bos_token", "tokenizer.ggml.add_eos_token", "tokenizer.chat_template", + "bert.pooling_type", }, } diff --git a/server/prompt_test.go b/server/prompt_test.go index 9c4da068..02d23785 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -3,7 +3,6 @@ package server import ( "bytes" "context" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -11,14 +10,6 @@ import ( "github.com/ollama/ollama/template" ) -func tokenize(_ context.Context, s string) (tokens []int, err error) { - for range strings.Fields(s) { - tokens = append(tokens, len(tokens)) - } - - return -} - func TestChatPrompt(t *testing.T) { type expect struct { prompt string @@ -192,15 +183,11 @@ func TestChatPrompt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} - prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil) + prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) if err != nil { t.Fatal(err) } - if tt.prompt != prompt { - t.Errorf("expected %q, got %q", tt.prompt, prompt) - } - if diff := cmp.Diff(prompt, tt.prompt); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 04174b92..cb548ebd 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -85,6 +85,8 @@ func checkFileExists(t *testing.T, p string, expect []string) { } func TestCreateFromBin(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -111,6 +113,8 @@ func TestCreateFromBin(t *testing.T) { } func TestCreateFromModel(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -152,6 +156,8 @@ func TestCreateFromModel(t *testing.T) { } func TestCreateRemovesLayers(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -199,6 +205,8 @@ func TestCreateRemovesLayers(t *testing.T) { } func TestCreateUnsetsSystem(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -255,6 +263,8 @@ func TestCreateUnsetsSystem(t *testing.T) { } func TestCreateMergeParameters(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -358,6 +368,8 @@ func TestCreateMergeParameters(t *testing.T) { } func TestCreateReplacesMessages(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -434,6 +446,8 @@ func TestCreateReplacesMessages(t *testing.T) { } func TestCreateTemplateSystem(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -480,6 +494,8 @@ func TestCreateTemplateSystem(t *testing.T) { } func TestCreateLicenses(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -526,6 +542,8 @@ func TestCreateLicenses(t *testing.T) { } func TestCreateDetectTemplate(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go index 00303bd1..33a97a73 100644 --- a/server/routes_delete_test.go +++ b/server/routes_delete_test.go @@ -8,12 +8,15 @@ import ( "path/filepath" "testing" + "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/types/model" ) func TestDelete(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) envconfig.LoadConfig() @@ -77,6 +80,8 @@ func TestDelete(t *testing.T) { } func TestDeleteDuplicateLayers(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() t.Setenv("OLLAMA_MODELS", p) var s Server diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go new file mode 100644 index 00000000..9d899328 --- /dev/null +++ b/server/routes_generate_test.go @@ -0,0 +1,651 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/llm" +) + +type mockRunner struct { + llm.LlamaServer + + // CompletionRequest is only valid until the next call to Completion + llm.CompletionRequest + llm.CompletionResponse +} + +func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + m.CompletionRequest = r + fn(m.CompletionResponse) + return nil +} + +func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) { + for range strings.Fields(s) { + tokens = append(tokens, len(tokens)) + } + + return +} + +func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) { + return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + return mock, nil + } +} + +func TestGenerateChat(t *testing.T) { + gin.SetMode(gin.TestMode) + + mock := mockRunner{ + CompletionResponse: llm.CompletionResponse{ + Done: true, + DoneReason: "stop", + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: gpu.GetGPUInfo, + getCpuFn: gpu.GetCPUInfo, + reschedDelay: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + req.successCh <- &runnerRef{ + llama: &mock, + } + }, + }, + } + + go s.sched.Run(context.TODO()) + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf(`FROM %s + TEMPLATE """ +{{- if .System }}System: {{ .System }} {{ end }} +{{- if .Prompt }}User: {{ .Prompt }} {{ end }} +{{- if .Response }}Assistant: {{ .Response }} {{ end }}""" +`, createBinFile(t, llm.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []llm.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + })), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("missing body", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, nil) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing model", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{}) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing capabilities", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "bert", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(0), + }, []llm.Tensor{})), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "bert", + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("load model", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var actual api.ChatResponse + if err := json.NewDecoder(w.Body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != "test" { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done true, got false") + } + + if actual.DoneReason != "load" { + t.Errorf("expected done reason load, got %s", actual.DoneReason) + } + }) + + checkChatResponse := func(t *testing.T, body io.Reader, model, content string) { + t.Helper() + + var actual api.ChatResponse + if err := json.NewDecoder(body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != model { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done false, got true") + } + + if actual.DoneReason != "stop" { + t.Errorf("expected done reason stop, got %s", actual.DoneReason) + } + + if diff := cmp.Diff(actual.Message, api.Message{ + Role: "assistant", + Content: content, + }); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + if actual.PromptEvalCount == 0 { + t.Errorf("expected prompt eval count > 0, got 0") + } + + if actual.PromptEvalDuration == 0 { + t.Errorf("expected prompt eval duration > 0, got 0") + } + + if actual.EvalCount == 0 { + t.Errorf("expected eval count > 0, got 0") + } + + if actual.EvalDuration == 0 { + t.Errorf("expected eval duration > 0, got 0") + } + + if actual.LoadDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + + if actual.TotalDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + } + + mock.CompletionResponse.Content = "Hi!" + t.Run("messages", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test", "Hi!") + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Model: "test-system", + Modelfile: "FROM test\nSYSTEM You are a helpful assistant.", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("messages with model system", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test-system", "Hi!") + }) + + mock.CompletionResponse.Content = "Abra kadabra!" + t.Run("messages with system", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "system", Content: "You can perform magic tricks."}, + {Role: "user", Content: "Hello!"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") + }) + + t.Run("messages with interleaved system", func(t *testing.T) { + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test-system", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + {Role: "assistant", Content: "I can help you with that."}, + {Role: "system", Content: "You can perform magic tricks."}, + {Role: "user", Content: "Help me write tests."}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") + }) +} + +func TestGenerate(t *testing.T) { + gin.SetMode(gin.TestMode) + + mock := mockRunner{ + CompletionResponse: llm.CompletionResponse{ + Done: true, + DoneReason: "stop", + PromptEvalCount: 1, + PromptEvalDuration: 1, + EvalCount: 1, + EvalDuration: 1, + }, + } + + s := Server{ + sched: &Scheduler{ + pendingReqCh: make(chan *LlmRequest, 1), + finishedReqCh: make(chan *LlmRequest, 1), + expiredCh: make(chan *runnerRef, 1), + unloadedCh: make(chan any, 1), + loaded: make(map[string]*runnerRef), + newServerFn: newMockServer(&mock), + getGpuFn: gpu.GetGPUInfo, + getCpuFn: gpu.GetCPUInfo, + reschedDelay: 250 * time.Millisecond, + loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + req.successCh <- &runnerRef{ + llama: &mock, + } + }, + }, + } + + go s.sched.Run(context.TODO()) + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf(`FROM %s + TEMPLATE """ +{{- if .System }}System: {{ .System }} {{ end }} +{{- if .Prompt }}User: {{ .Prompt }} {{ end }} +{{- if .Response }}Assistant: {{ .Response }} {{ end }}""" +`, createBinFile(t, llm.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(1), + "llama.context_length": uint32(8192), + "llama.embedding_length": uint32(4096), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(8), + "tokenizer.ggml.tokens": []string{""}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []llm.Tensor{ + {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))}, + })), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("missing body", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, nil) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing model", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{}) + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("missing capabilities", func(t *testing.T) { + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "bert", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(0), + }, []llm.Tensor{})), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + w = createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "bert", + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("load model", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + var actual api.GenerateResponse + if err := json.NewDecoder(w.Body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != "test" { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done true, got false") + } + + if actual.DoneReason != "load" { + t.Errorf("expected done reason load, got %s", actual.DoneReason) + } + }) + + checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) { + t.Helper() + + var actual api.GenerateResponse + if err := json.NewDecoder(body).Decode(&actual); err != nil { + t.Fatal(err) + } + + if actual.Model != model { + t.Errorf("expected model test, got %s", actual.Model) + } + + if !actual.Done { + t.Errorf("expected done false, got true") + } + + if actual.DoneReason != "stop" { + t.Errorf("expected done reason stop, got %s", actual.DoneReason) + } + + if actual.Response != content { + t.Errorf("expected response %s, got %s", content, actual.Response) + } + + if actual.Context == nil { + t.Errorf("expected context not nil") + } + + if actual.PromptEvalCount == 0 { + t.Errorf("expected prompt eval count > 0, got 0") + } + + if actual.PromptEvalDuration == 0 { + t.Errorf("expected prompt eval duration > 0, got 0") + } + + if actual.EvalCount == 0 { + t.Errorf("expected eval count > 0, got 0") + } + + if actual.EvalDuration == 0 { + t.Errorf("expected eval duration > 0, got 0") + } + + if actual.LoadDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + + if actual.TotalDuration == 0 { + t.Errorf("expected load duration > 0, got 0") + } + } + + mock.CompletionResponse.Content = "Hi!" + t.Run("prompt", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Hello!", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test", "Hi!") + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Model: "test-system", + Modelfile: "FROM test\nSYSTEM You are a helpful assistant.", + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", w.Code) + } + + t.Run("prompt with model system", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Hello!", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test-system", "Hi!") + }) + + mock.CompletionResponse.Content = "Abra kadabra!" + t.Run("prompt with system", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Hello!", + System: "You can perform magic tricks.", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") + }) + + t.Run("prompt with template", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Help me write tests.", + System: "You can perform magic tricks.", + Template: `{{- if .System }}{{ .System }} {{ end }} +{{- if .Prompt }}### USER {{ .Prompt }} {{ end }} +{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") + }) + + t.Run("raw", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test-system", + Prompt: "Help me write tests.", + Raw: true, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", w.Code) + } + + if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) +} diff --git a/server/routes_list_test.go b/server/routes_list_test.go index d04be9d6..c2d9c113 100644 --- a/server/routes_list_test.go +++ b/server/routes_list_test.go @@ -7,11 +7,14 @@ import ( "slices" "testing" + "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" ) func TestList(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("OLLAMA_MODELS", t.TempDir()) envconfig.LoadConfig() From 4cb5d7decc08f4c7c136f81b2b5f1c74ebffbc62 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 16 Jul 2024 11:09:00 -0700 Subject: [PATCH 04/29] server: omit model system prompt if empty (#5717) --- server/routes.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/routes.go b/server/routes.go index 9712d895..d0cbe6cc 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1310,7 +1310,7 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - if req.Messages[0].Role != "system" { + if req.Messages[0].Role != "system" && m.System != "" { req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...) } From 5afbb60fc452965a4a53f1e46816ea41298269c6 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Jul 2024 09:38:46 -0700 Subject: [PATCH 05/29] fix unmarshal type errors --- server/model.go | 48 +++++++++++++++++--------------------------- server/model_test.go | 33 +++++++++++++++++++----------- 2 files changed, 39 insertions(+), 42 deletions(-) diff --git a/server/model.go b/server/model.go index be318db9..9e22d63a 100644 --- a/server/model.go +++ b/server/model.go @@ -327,7 +327,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { var kv map[string]string // execute the subtree with placeholders to identify the keys - if err := json.Unmarshal(b.Bytes(), &kv); err != nil { + // trim any commands that might exist in the template + if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil { return nil, false } @@ -342,35 +343,26 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { } } - var sm []map[string]any - decoder := json.NewDecoder(strings.NewReader(s)) - for { - // incrementally decode the JSON into a list of JSON objects - // skipping over any invalid tokens - if err := decoder.Decode(&sm); err != nil { - if errors.Is(err, io.EOF) { - break - } - - if errors.As(err, new(*json.SyntaxError)) { - r := decoder.Buffered() - if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil { - break - } - - decoder = json.NewDecoder(r) - continue - } - + var objs []map[string]any + for offset := 0; offset < len(s); { + if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) { + break + } else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) { + // skip over any syntax errors + offset += int(syntax.Offset) + } else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) { + // skip over any unmarshalable types + offset += int(unmarshalType.Offset) + } else if err != nil { return nil, false + } else { + // break when an object is decoded + break } - - // break as soon as a valid object is decoded - break } var toolCalls []api.ToolCall - for _, kv := range sm { + for _, kv := range objs { call := api.ToolCall{ ID: uuid.New().String(), Type: "function", @@ -388,9 +380,5 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { toolCalls = append(toolCalls, call) } - if len(toolCalls) > 0 { - return toolCalls, true - } - - return nil, false + return toolCalls, len(toolCalls) > 0 } diff --git a/server/model_test.go b/server/model_test.go index 02578192..d39f2891 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -136,11 +136,16 @@ func TestExecuteWithTools(t *testing.T) { cases := []struct { model string output string + ok bool }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`}, +The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, + {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: + + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, + {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, {"command-r-plus", "Action: ```json" + ` [ { @@ -158,8 +163,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`} } } ] -` + "```"}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`}, +` + "```", true}, + {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, + {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, + {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, } var tools []api.Tool @@ -216,17 +223,19 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`} t.Run("parse", func(t *testing.T) { m := &Model{Template: tmpl} actual, ok := m.parseToolCalls(tt.output) - if !ok { - t.Fatal("failed to parse tool calls") + if ok != tt.ok { + t.Fatalf("expected %t, got %t", tt.ok, ok) } - for i := range actual { - // ID is randomly generated so clear it for comparison - actual[i].ID = "" - } + if tt.ok { + for i := range actual { + // ID is randomly generated so clear it for comparison + actual[i].ID = "" + } - if diff := cmp.Diff(actual, calls); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) + if diff := cmp.Diff(actual, calls); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } } }) }) From 987dbab0b063653b2be71060449c8add7b76cc6e Mon Sep 17 00:00:00 2001 From: royjhan <65097070+royjhan@users.noreply.github.com> Date: Tue, 16 Jul 2024 13:36:08 -0700 Subject: [PATCH 06/29] OpenAI: /v1/embeddings compatibility (#5285) * OpenAI v1 models * Empty List Testing * Add back envconfig * v1/models docs * Remove Docs * OpenAI batch embed compatibility * merge conflicts * integrate with api/embed * ep * merge conflicts * request tests * rm resp test * merge conflict * merge conflict * test fixes * test fn renaming * input validation for empty string --------- Co-authored-by: jmorganca --- openai/openai.go | 111 ++++++++++++++++++++++++++++++++++++++++++ openai/openai_test.go | 72 +++++++++++++++++++++++++++ server/routes.go | 1 + 3 files changed, 184 insertions(+) diff --git a/openai/openai.go b/openai/openai.go index b289d73e..88bdaec1 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -61,6 +61,11 @@ type ResponseFormat struct { Type string `json:"type"` } +type EmbedRequest struct { + Input any `json:"input"` + Model string `json:"model"` +} + type ChatCompletionRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -134,11 +139,23 @@ type Model struct { OwnedBy string `json:"owned_by"` } +type Embedding struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` +} + type ListCompletion struct { Object string `json:"object"` Data []Model `json:"data"` } +type EmbeddingList struct { + Object string `json:"object"` + Data []Embedding `json:"data"` + Model string `json:"model"` +} + func NewError(code int, message string) ErrorResponse { var etype string switch code { @@ -262,6 +279,27 @@ func toListCompletion(r api.ListResponse) ListCompletion { } } +func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { + if r.Embeddings != nil { + var data []Embedding + for i, e := range r.Embeddings { + data = append(data, Embedding{ + Object: "embedding", + Embedding: e, + Index: i, + }) + } + + return EmbeddingList{ + Object: "list", + Data: data, + Model: model, + } + } + + return EmbeddingList{} +} + func toModel(r api.ShowResponse, m string) Model { return Model{ Id: m, @@ -465,6 +503,11 @@ type RetrieveWriter struct { model string } +type EmbedWriter struct { + BaseWriter + model string +} + func (w *BaseWriter) writeError(code int, data []byte) (int, error) { var serr api.StatusError err := json.Unmarshal(data, &serr) @@ -630,6 +673,33 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) { return w.writeResponse(data) } +func (w *EmbedWriter) writeResponse(data []byte) (int, error) { + var embedResponse api.EmbedResponse + err := json.Unmarshal(data, &embedResponse) + + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) + + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *EmbedWriter) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + func ListMiddleware() gin.HandlerFunc { return func(c *gin.Context) { w := &ListWriter{ @@ -693,6 +763,47 @@ func CompletionsMiddleware() gin.HandlerFunc { id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), } + c.Writer = w + c.Next() + } +} + +func EmbeddingsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req EmbedRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + if req.Input == "" { + req.Input = []string{""} + } + + if req.Input == nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) + return + } + + if v, ok := req.Input.([]any); ok && len(v) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) + return + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &EmbedWriter{ + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: req.Model, + } + c.Writer = w c.Next() diff --git a/openai/openai_test.go b/openai/openai_test.go index 99f8baaf..5fc22b88 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -161,6 +161,78 @@ func TestMiddlewareRequests(t *testing.T) { } }, }, + { + Name: "embed handler single input", + Method: http.MethodPost, + Path: "/api/embed", + Handler: EmbeddingsMiddleware, + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Input: "Hello", + Model: "test-model", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, req *http.Request) { + var embedReq api.EmbedRequest + if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { + t.Fatal(err) + } + + if embedReq.Input != "Hello" { + t.Fatalf("expected 'Hello', got %s", embedReq.Input) + } + + if embedReq.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", embedReq.Model) + } + }, + }, + { + Name: "embed handler batch input", + Method: http.MethodPost, + Path: "/api/embed", + Handler: EmbeddingsMiddleware, + Setup: func(t *testing.T, req *http.Request) { + body := EmbedRequest{ + Input: []string{"Hello", "World"}, + Model: "test-model", + } + + bodyBytes, _ := json.Marshal(body) + + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, req *http.Request) { + var embedReq api.EmbedRequest + if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil { + t.Fatal(err) + } + + input, ok := embedReq.Input.([]any) + + if !ok { + t.Fatalf("expected input to be a list") + } + + if input[0].(string) != "Hello" { + t.Fatalf("expected 'Hello', got %s", input[0]) + } + + if input[1].(string) != "World" { + t.Fatalf("expected 'World', got %s", input[1]) + } + + if embedReq.Model != "test-model" { + t.Fatalf("expected 'test-model', got %s", embedReq.Model) + } + }, + }, } gin.SetMode(gin.TestMode) diff --git a/server/routes.go b/server/routes.go index d0cbe6cc..d22a099a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1064,6 +1064,7 @@ func (s *Server) GenerateRoutes() http.Handler { // Compatibility endpoints r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) + r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) From 5a83f79afdf14b60008120d6fbd4fe94ba3f5241 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Jul 2024 13:48:38 -0700 Subject: [PATCH 07/29] remove unneeded tool calls --- api/types.go | 2 -- server/model.go | 7 +------ server/model_test.go | 7 ------- 3 files changed, 1 insertion(+), 15 deletions(-) diff --git a/api/types.go b/api/types.go index e670d114..b18ee228 100644 --- a/api/types.go +++ b/api/types.go @@ -115,8 +115,6 @@ type Message struct { } type ToolCall struct { - ID string `json:"id"` - Type string `json:"type"` Function struct { Name string `json:"name"` Arguments map[string]any `json:"arguments"` diff --git a/server/model.go b/server/model.go index 9e22d63a..de65d6b6 100644 --- a/server/model.go +++ b/server/model.go @@ -16,7 +16,6 @@ import ( "strings" "text/template/parse" - "github.com/google/uuid" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" "github.com/ollama/ollama/llm" @@ -363,11 +362,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { var toolCalls []api.ToolCall for _, kv := range objs { - call := api.ToolCall{ - ID: uuid.New().String(), - Type: "function", - } - + var call api.ToolCall for k, v := range kv { switch k { case name: diff --git a/server/model_test.go b/server/model_test.go index d39f2891..2e9dad3d 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -181,7 +181,6 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, calls := []api.ToolCall{ { - Type: "function", Function: function{ Name: "get_current_weather", Arguments: map[string]any{ @@ -191,7 +190,6 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, }, }, { - Type: "function", Function: function{ Name: "get_current_weather", Arguments: map[string]any{ @@ -228,11 +226,6 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, } if tt.ok { - for i := range actual { - // ID is randomly generated so clear it for comparison - actual[i].ID = "" - } - if diff := cmp.Diff(actual, calls); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } From 97c20ede33ca67f8bf21e55ffe318831d83290d0 Mon Sep 17 00:00:00 2001 From: Thorsten Sommer Date: Tue, 16 Jul 2024 23:24:27 +0200 Subject: [PATCH 08/29] README: Added AI Studio to the list of UIs (#5721) * Added AI Studio to the list of UIs --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index eb5e8532..40f4aede 100644 --- a/README.md +++ b/README.md @@ -294,6 +294,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS) +- [AI Studio](https://github.com/MindWorkAI/AI-Studio) ### Terminal From d290e87513664be8ca3120348614d124991ccb86 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Jun 2024 19:13:36 -0700 Subject: [PATCH 09/29] add suffix support to generate endpoint this change is triggered by the presence of "suffix", particularly useful for code completion tasks --- api/types.go | 3 ++ server/images.go | 17 ++++++-- server/routes.go | 40 +++++++++++------- server/routes_generate_test.go | 77 ++++++++++++++++++++++++++++++---- template/template.go | 10 ++++- template/template_test.go | 35 ++++++++++++++++ 6 files changed, 155 insertions(+), 27 deletions(-) diff --git a/api/types.go b/api/types.go index e670d114..3029fca8 100644 --- a/api/types.go +++ b/api/types.go @@ -47,6 +47,9 @@ type GenerateRequest struct { // Prompt is the textual prompt to send to the model. Prompt string `json:"prompt"` + // Suffix is the text that comes after the inserted text. + Suffix string `json:"suffix"` + // System overrides the model's default system message/prompt. System string `json:"system"` diff --git a/server/images.go b/server/images.go index 1b87888e..5e4e8858 100644 --- a/server/images.go +++ b/server/images.go @@ -34,13 +34,19 @@ import ( "github.com/ollama/ollama/version" ) -var errCapabilityCompletion = errors.New("completion") +var ( + errCapabilities = errors.New("does not support") + errCapabilityCompletion = errors.New("completion") + errCapabilityTools = errors.New("tools") + errCapabilityInsert = errors.New("insert") +) type Capability string const ( CapabilityCompletion = Capability("completion") CapabilityTools = Capability("tools") + CapabilityInsert = Capability("insert") ) type registryOptions struct { @@ -93,7 +99,12 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { } case CapabilityTools: if !slices.Contains(m.Template.Vars(), "tools") { - errs = append(errs, errors.New("tools")) + errs = append(errs, errCapabilityTools) + } + case CapabilityInsert: + vars := m.Template.Vars() + if !slices.Contains(vars, "suffix") { + errs = append(errs, errCapabilityInsert) } default: slog.Error("unknown capability", "capability", cap) @@ -102,7 +113,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error { } if err := errors.Join(errs...); err != nil { - return fmt.Errorf("does not support %w", errors.Join(errs...)) + return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) } return nil diff --git a/server/routes.go b/server/routes.go index d22a099a..c7f74fa4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -122,6 +122,10 @@ func (s *Server) GenerateHandler(c *gin.Context) { } caps := []Capability{CapabilityCompletion} + if req.Suffix != "" { + caps = append(caps, CapabilityInsert) + } + r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) @@ -150,19 +154,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { prompt := req.Prompt if !req.Raw { - var msgs []api.Message - if req.System != "" { - msgs = append(msgs, api.Message{Role: "system", Content: req.System}) - } else if m.System != "" { - msgs = append(msgs, api.Message{Role: "system", Content: m.System}) - } - - for _, i := range images { - msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) - } - - msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt}) - tmpl := m.Template if req.Template != "" { tmpl, err = template.Parse(req.Template) @@ -183,7 +174,26 @@ func (s *Server) GenerateHandler(c *gin.Context) { b.WriteString(s) } - if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil { + var values template.Values + if req.Suffix != "" { + values.Prompt = prompt + values.Suffix = req.Suffix + } else { + var msgs []api.Message + if req.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: req.System}) + } else if m.System != "" { + msgs = append(msgs, api.Message{Role: "system", Content: m.System}) + } + + for _, i := range images { + msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)}) + } + + values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) + } + + if err := tmpl.Execute(&b, values); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -1394,7 +1404,7 @@ func (s *Server) ChatHandler(c *gin.Context) { func handleScheduleError(c *gin.Context, name string, err error) { switch { - case errors.Is(err, errRequired): + case errors.Is(err, errCapabilities), errors.Is(err, errRequired): c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) case errors.Is(err, context.Canceled): c.JSON(499, gin.H{"error": "request canceled"}) diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 9d899328..c914b300 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -73,6 +73,8 @@ func TestGenerateChat(t *testing.T) { getCpuFn: gpu.GetCPUInfo, reschedDelay: 250 * time.Millisecond, loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) { + // add 10ms delay to simulate loading + time.Sleep(10 * time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, } @@ -83,7 +85,7 @@ func TestGenerateChat(t *testing.T) { go s.sched.Run(context.TODO()) w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "test", + Model: "test", Modelfile: fmt.Sprintf(`FROM %s TEMPLATE """ {{- if .System }}System: {{ .System }} {{ end }} @@ -141,9 +143,9 @@ func TestGenerateChat(t *testing.T) { } }) - t.Run("missing capabilities", func(t *testing.T) { + t.Run("missing capabilities chat", func(t *testing.T) { w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "bert", + Model: "bert", Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ "general.architecture": "bert", "bert.pooling_type": uint32(0), @@ -243,7 +245,7 @@ func TestGenerateChat(t *testing.T) { } if actual.TotalDuration == 0 { - t.Errorf("expected load duration > 0, got 0") + t.Errorf("expected total duration > 0, got 0") } } @@ -379,7 +381,7 @@ func TestGenerate(t *testing.T) { go s.sched.Run(context.TODO()) w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "test", + Model: "test", Modelfile: fmt.Sprintf(`FROM %s TEMPLATE """ {{- if .System }}System: {{ .System }} {{ end }} @@ -437,9 +439,9 @@ func TestGenerate(t *testing.T) { } }) - t.Run("missing capabilities", func(t *testing.T) { + t.Run("missing capabilities generate", func(t *testing.T) { w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ - Name: "bert", + Model: "bert", Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{ "general.architecture": "bert", "bert.pooling_type": uint32(0), @@ -464,6 +466,22 @@ func TestGenerate(t *testing.T) { } }) + t.Run("missing capabilities suffix", func(t *testing.T) { + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "def add(", + Suffix: " return c", + }) + + if w.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + t.Run("load model", func(t *testing.T) { w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ Model: "test", @@ -540,7 +558,7 @@ func TestGenerate(t *testing.T) { } if actual.TotalDuration == 0 { - t.Errorf("expected load duration > 0, got 0") + t.Errorf("expected total duration > 0, got 0") } } @@ -632,6 +650,49 @@ func TestGenerate(t *testing.T) { checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!") }) + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Model: "test-suffix", + Modelfile: `FROM test +TEMPLATE """{{- if .Suffix }}
 {{ .Prompt }} {{ .Suffix }} 
+{{- else }}{{ .Prompt }}
+{{- end }}"""`,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status 200, got %d", w.Code)
+	}
+
+	t.Run("prompt with suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+			Suffix: "    return c",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "
 def add(     return c "); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
+	t.Run("prompt without suffix", func(t *testing.T) {
+		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+			Model:  "test-suffix",
+			Prompt: "def add(",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status 200, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
 	t.Run("raw", func(t *testing.T) {
 		w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 			Model:  "test-system",
diff --git a/template/template.go b/template/template.go
index 7cdb30ef..5330c0fa 100644
--- a/template/template.go
+++ b/template/template.go
@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
 type Values struct {
 	Messages []api.Message
 	Tools    []api.Tool
+	Prompt   string
+	Suffix   string
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
 
 func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
-	if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
+	if v.Prompt != "" && v.Suffix != "" {
+		return t.Template.Execute(w, map[string]any{
+			"Prompt":   v.Prompt,
+			"Suffix":   v.Suffix,
+			"Response": "",
+		})
+	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
 			"System":   system,
 			"Messages": messages,
diff --git a/template/template_test.go b/template/template_test.go
index c678f1b1..ae0db80b 100644
--- a/template/template_test.go
+++ b/template/template_test.go
@@ -359,3 +359,38 @@ Answer: `,
 		})
 	}
 }
+
+func TestExecuteWithSuffix(t *testing.T) {
+	tmpl, err := Parse(`{{- if .Suffix }}
 {{ .Prompt }} {{ .Suffix }} 
+{{- else }}{{ .Prompt }}
+{{- end }}`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	cases := []struct {
+		name   string
+		values Values
+		expect string
+	}{
+		{
+			"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
+		},
+		{
+			"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "
 def add( return x ",
+		},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			var b bytes.Buffer
+			if err := tmpl.Execute(&b, tt.values); err != nil {
+				t.Fatal(err)
+			}
+
+			if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}

From c279f96371575484d212ab069db96ee38dddceb1 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Tue, 16 Jul 2024 14:51:19 -0700
Subject: [PATCH 10/29] remove ToolCall from GenerateResponse

---
 api/types.go     | 3 ---
 server/routes.go | 5 -----
 2 files changed, 8 deletions(-)

diff --git a/api/types.go b/api/types.go
index c3f0bae0..e687b8a4 100644
--- a/api/types.go
+++ b/api/types.go
@@ -405,9 +405,6 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`
 
-	// ToolCalls is the list of tools the model wants to call
-	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
-
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
 
diff --git a/server/routes.go b/server/routes.go
index c7f74fa4..b4a8b4ac 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -275,11 +275,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		r.Response = sb.String()
-		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-			r.ToolCalls = toolCalls
-			r.Response = ""
-		}
-
 		c.JSON(http.StatusOK, r)
 		return
 	}

From 0d41623b52b77725624a9f4dbc4c0f9a356a9021 Mon Sep 17 00:00:00 2001
From: royjhan <65097070+royjhan@users.noreply.github.com>
Date: Tue, 16 Jul 2024 20:50:14 -0700
Subject: [PATCH 11/29] OpenAI: Add Suffix to `v1/completions` (#5611)

* add suffix

* remove todo

* remove TODO

* add to test

* rm outdated prompt tokens info md

* fix test

* fix test
---
 docs/openai.md        | 4 ----
 openai/openai.go      | 4 ++--
 openai/openai_test.go | 5 +++++
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/docs/openai.md b/docs/openai.md
index 9dda05c3..248ba74a 100644
--- a/docs/openai.md
+++ b/docs/openai.md
@@ -103,10 +103,6 @@ curl http://localhost:11434/v1/chat/completions \
 - [ ] `user`
 - [ ] `n`
 
-#### Notes
-
-- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
-
 ## Models
 
 Before using a model, pull it locally `ollama pull`:
diff --git a/openai/openai.go b/openai/openai.go
index 88bdaec1..81e4011d 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -111,6 +111,7 @@ type CompletionRequest struct {
 	Stream           bool     `json:"stream"`
 	Temperature      *float32 `json:"temperature"`
 	TopP             float32  `json:"top_p"`
+	Suffix           string   `json:"suffix"`
 }
 
 type Completion struct {
@@ -188,7 +189,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -234,7 +234,6 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -475,6 +474,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 		Prompt:  r.Prompt,
 		Options: options,
 		Stream:  &r.Stream,
+		Suffix:  r.Suffix,
 	}, nil
 }
 
diff --git a/openai/openai_test.go b/openai/openai_test.go
index 5fc22b88..046ee69c 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -85,6 +85,7 @@ func TestMiddlewareRequests(t *testing.T) {
 					Prompt:      "Hello",
 					Temperature: &temp,
 					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
 				}
 
 				bodyBytes, _ := json.Marshal(body)
@@ -115,6 +116,10 @@ func TestMiddlewareRequests(t *testing.T) {
 				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
 					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
 				}
+
+				if genReq.Suffix != "suffix" {
+					t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
+				}
 			},
 		},
 		{

From 154f6f45d4acd4ea1f2e35cac3b90eb6faeea6bd Mon Sep 17 00:00:00 2001
From: royjhan <65097070+royjhan@users.noreply.github.com>
Date: Tue, 16 Jul 2024 20:52:59 -0700
Subject: [PATCH 12/29] OpenAI: Support Tools (#5614)

* reopen pr

* tools

* remove tc from stream for now

* ID and Function

* openai expects arguments to be a string (#5739)

* mutually exclusive content and tool calls

* clean up

---------

Co-authored-by: Jeffrey Morgan 
---
 openai/openai.go | 57 ++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 53 insertions(+), 4 deletions(-)

diff --git a/openai/openai.go b/openai/openai.go
index 81e4011d..01864e48 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log/slog"
 	"math/rand"
 	"net/http"
 	"strings"
@@ -29,8 +30,9 @@ type ErrorResponse struct {
 }
 
 type Message struct {
-	Role    string `json:"role"`
-	Content any    `json:"content"`
+	Role      string     `json:"role"`
+	Content   any        `json:"content"`
+	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
 type Choice struct {
@@ -78,6 +80,7 @@ type ChatCompletionRequest struct {
 	PresencePenalty  *float64        `json:"presence_penalty_penalty"`
 	TopP             *float64        `json:"top_p"`
 	ResponseFormat   *ResponseFormat `json:"response_format"`
+	Tools            []api.Tool      `json:"tools"`
 }
 
 type ChatCompletion struct {
@@ -133,6 +136,15 @@ type CompletionChunk struct {
 	SystemFingerprint string                `json:"system_fingerprint"`
 }
 
+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string `json:"name"`
+		Arguments string `json:"arguments"`
+	} `json:"function"`
+}
+
 type Model struct {
 	Id      string `json:"id"`
 	Object  string `json:"object"`
@@ -171,7 +183,31 @@ func NewError(code int, message string) ErrorResponse {
 	return ErrorResponse{Error{Type: etype, Message: message}}
 }
 
+func toolCallId() string {
+	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
+	b := make([]byte, 8)
+	for i := range b {
+		b[i] = letterBytes[rand.Intn(len(letterBytes))]
+	}
+	return "call_" + strings.ToLower(string(b))
+}
+
 func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
+	for i, tc := range r.Message.ToolCalls {
+		toolCalls[i].ID = toolCallId()
+		toolCalls[i].Type = "function"
+		toolCalls[i].Function.Name = tc.Function.Name
+
+		args, err := json.Marshal(tc.Function.Arguments)
+		if err != nil {
+			slog.Error("could not marshall function arguments to json", "error", err)
+			continue
+		}
+
+		toolCalls[i].Function.Arguments = string(args)
+	}
+
 	return ChatCompletion{
 		Id:                id,
 		Object:            "chat.completion",
@@ -180,7 +216,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 		SystemFingerprint: "fp_ollama",
 		Choices: []Choice{{
 			Index:   0,
-			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
+			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
 			FinishReason: func(reason string) *string {
 				if len(reason) > 0 {
 					return &reason
@@ -366,7 +402,19 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 			}
 			messages = append(messages, message)
 		default:
-			return nil, fmt.Errorf("invalid message content type: %T", content)
+			if msg.ToolCalls == nil {
+				return nil, fmt.Errorf("invalid message content type: %T", content)
+			}
+
+			toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
+			for i, tc := range msg.ToolCalls {
+				toolCalls[i].Function.Name = tc.Function.Name
+				err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
+				if err != nil {
+					return nil, fmt.Errorf("invalid tool call arguments")
+				}
+			}
+			messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
 		}
 	}
 
@@ -424,6 +472,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		Format:   format,
 		Options:  options,
 		Stream:   &r.Stream,
+		Tools:    r.Tools,
 	}, nil
 }
 

From d281a6e6036cf122fc9828ca93c58d7d3d57502f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A1kozdi=20Gy=C3=B6rgy?= 
Date: Wed, 17 Jul 2024 19:24:44 +0200
Subject: [PATCH 13/29] add sidellama link (#5702)

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 40f4aede..b96f4c16 100644
--- a/README.md
+++ b/README.md
@@ -295,6 +295,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
 - [AI Studio](https://github.com/MindWorkAI/AI-Studio)
+- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
 
 ### Terminal
 

From 5b82960df840a8bd545b9a60a1b69c089e0e24f1 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Wed, 17 Jul 2024 10:39:22 -0700
Subject: [PATCH 14/29] stub response (#5750)

---
 template/template.go | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/template/template.go b/template/template.go
index 5330c0fa..85b4d21a 100644
--- a/template/template.go
+++ b/template/template.go
@@ -217,6 +217,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 			"System":   system,
 			"Messages": messages,
 			"Tools":    v.Tools,
+			"Response": "",
 		})
 	}
 
@@ -270,8 +271,9 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 
 	tree := parse.Tree{Root: nodes.(*parse.ListNode)}
 	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
-		"System": system,
-		"Prompt": prompt,
+		"System":   system,
+		"Prompt":   prompt,
+		"Response": "",
 	}); err != nil {
 		return err
 	}

From 5fd698812655fb87e31262072f8a3e6a34b438a9 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Wed, 17 Jul 2024 11:02:36 -0700
Subject: [PATCH 15/29] parse tool call as individual objects

---
 server/model.go                               |  8 ++--
 server/model_test.go                          |  4 ++
 .../tools/llama3-groq-tool-use.gotmpl         | 43 +++++++++++++++++++
 .../testdata/tools/llama3-groq-tool-use.out   | 24 +++++++++++
 4 files changed, 76 insertions(+), 3 deletions(-)
 create mode 100644 server/testdata/tools/llama3-groq-tool-use.gotmpl
 create mode 100644 server/testdata/tools/llama3-groq-tool-use.out

diff --git a/server/model.go b/server/model.go
index de65d6b6..e5d6179b 100644
--- a/server/model.go
+++ b/server/model.go
@@ -344,7 +344,9 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 
 	var objs []map[string]any
 	for offset := 0; offset < len(s); {
-		if err := json.NewDecoder(strings.NewReader(s[offset:])).Decode(&objs); errors.Is(err, io.EOF) {
+		var obj map[string]any
+		decoder := json.NewDecoder(strings.NewReader(s[offset:]))
+		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) {
 			break
 		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
 			// skip over any syntax errors
@@ -355,8 +357,8 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		} else if err != nil {
 			return nil, false
 		} else {
-			// break when an object is decoded
-			break
+			offset += int(decoder.InputOffset())
+			objs = append(objs, obj)
 		}
 	}
 
diff --git a/server/model_test.go b/server/model_test.go
index 2e9dad3d..f0382843 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -167,6 +167,10 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 		{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
 		{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
 		{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
+		{"llama3-groq-tool-use", `
+{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
+{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
+`, true},
 	}
 
 	var tools []api.Tool
diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/server/testdata/tools/llama3-groq-tool-use.gotmpl
new file mode 100644
index 00000000..e174f8a5
--- /dev/null
+++ b/server/testdata/tools/llama3-groq-tool-use.gotmpl
@@ -0,0 +1,43 @@
+{{- if .Messages }}
+{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}
+{{- if .Tools }} You are provided with function signatures within  XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within  XML tags as follows:
+
+{"name": ,"arguments": }
+
+
+Here are the available tools:
+
+{{- range .Tools }} {{ json .Function }}
+{{- end }} 
+{{- end }}
+{{- end }}<|eot_id|>
+{{- range .Messages }}
+{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
+
+{{ if eq .Role "user" }}{{ .Content }}
+{{- else if eq .Role "assistant" }}
+{{- if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{- end }}
+
+{{- end }}
+{{- else if eq .Role "tool" }}
+{{ .Content }}
+
+{{- end }}<|eot_id|>
+{{- end }}
+{{- end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ else }}
+{{ if .System }}<|start_header_id|>system<|end_header_id|>
+
+{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
+
+{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
+
+{{ end }}{{ .Response }}
+{{- if .Response }}<|eot_id|>
+{{- end }}
\ No newline at end of file
diff --git a/server/testdata/tools/llama3-groq-tool-use.out b/server/testdata/tools/llama3-groq-tool-use.out
new file mode 100644
index 00000000..75a49558
--- /dev/null
+++ b/server/testdata/tools/llama3-groq-tool-use.out
@@ -0,0 +1,24 @@
+<|start_header_id|>system<|end_header_id|>
+
+You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within  XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within  XML tags as follows:
+
+{"name": ,"arguments": }
+
+
+Here are the available tools:
+ {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} <|eot_id|><|start_header_id|>user<|end_header_id|>
+
+What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+
+{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
+<|eot_id|><|start_header_id|>tool<|end_header_id|>
+
+
+22
+<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+

From b2554455572b28c0e18423d6fe6896cf7137dbd6 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Wed, 17 Jul 2024 15:35:11 -0700
Subject: [PATCH 16/29] marshal json automatically for some template values
 (#5758)

---
 api/types.go                                  | 73 ++++++++++++-------
 server/model.go                               | 18 +++--
 server/model_test.go                          | 13 +---
 server/testdata/tools/command-r-plus.gotmpl   |  2 +-
 server/testdata/tools/firefunction.gotmpl     |  4 +-
 .../tools/llama3-groq-tool-use.gotmpl         |  4 +-
 server/testdata/tools/mistral.gotmpl          |  4 +-
 template/template.go                          |  6 +-
 8 files changed, 72 insertions(+), 52 deletions(-)

diff --git a/api/types.go b/api/types.go
index e687b8a4..c7e9dce3 100644
--- a/api/types.go
+++ b/api/types.go
@@ -101,12 +101,19 @@ type ChatRequest struct {
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	// Tools is an optional list of tools the model has access to.
-	Tools []Tool `json:"tools,omitempty"`
+	Tools `json:"tools,omitempty"`
 
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
 
+type Tools []Tool
+
+func (t Tools) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
 // Message is a single message in a chat sequence. The message contains the
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
@@ -117,30 +124,6 @@ type Message struct {
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
 }
 
-type ToolCall struct {
-	Function struct {
-		Name      string         `json:"name"`
-		Arguments map[string]any `json:"arguments"`
-	} `json:"function"`
-}
-
-type Tool struct {
-	Type     string `json:"type"`
-	Function struct {
-		Name        string `json:"name"`
-		Description string `json:"description"`
-		Parameters  struct {
-			Type       string   `json:"type"`
-			Required   []string `json:"required"`
-			Properties map[string]struct {
-				Type        string   `json:"type"`
-				Description string   `json:"description"`
-				Enum        []string `json:"enum,omitempty"`
-			} `json:"properties"`
-		} `json:"parameters"`
-	} `json:"function"`
-}
-
 func (m *Message) UnmarshalJSON(b []byte) error {
 	type Alias Message
 	var a Alias
@@ -153,6 +136,46 @@ func (m *Message) UnmarshalJSON(b []byte) error {
 	return nil
 }
 
+type ToolCall struct {
+	Function ToolCallFunction `json:"function"`
+}
+
+type ToolCallFunction struct {
+	Name      string                    `json:"name"`
+	Arguments ToolCallFunctionArguments `json:"arguments"`
+}
+
+type ToolCallFunctionArguments map[string]any
+
+func (t *ToolCallFunctionArguments) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
+type Tool struct {
+	Type     string       `json:"type"`
+	Function ToolFunction `json:"function"`
+}
+
+type ToolFunction struct {
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Parameters  struct {
+		Type       string   `json:"type"`
+		Required   []string `json:"required"`
+		Properties map[string]struct {
+			Type        string   `json:"type"`
+			Description string   `json:"description"`
+			Enum        []string `json:"enum,omitempty"`
+		} `json:"properties"`
+	} `json:"parameters"`
+}
+
+func (t *ToolFunction) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
diff --git a/server/model.go b/server/model.go
index e5d6179b..65231ab1 100644
--- a/server/model.go
+++ b/server/model.go
@@ -311,12 +311,14 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	}
 
 	var b bytes.Buffer
-	if err := tmpl.Execute(&b, map[string][]map[string]any{
+	if err := tmpl.Execute(&b, map[string][]api.ToolCall{
 		"ToolCalls": {
 			{
-				"Function": map[string]any{
-					"Name":      "@@name@@",
-					"Arguments": "@@arguments@@",
+				Function: api.ToolCallFunction{
+					Name: "@@name@@",
+					Arguments: api.ToolCallFunctionArguments{
+						"@@argument@@": 1,
+					},
 				},
 			},
 		},
@@ -324,7 +326,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 		return nil, false
 	}
 
-	var kv map[string]string
+	var kv map[string]any
 	// execute the subtree with placeholders to identify the keys
 	// trim any commands that might exist in the template
 	if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
@@ -334,10 +336,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	// find the keys that correspond to the name and arguments fields
 	var name, arguments string
 	for k, v := range kv {
-		switch v {
-		case "@@name@@":
+		switch v.(type) {
+		case string:
 			name = k
-		case "@@arguments@@":
+		case map[string]any:
 			arguments = k
 		}
 	}
diff --git a/server/model_test.go b/server/model_test.go
index f0382843..7c826b06 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -115,11 +115,6 @@ func TestExtractFromZipFile(t *testing.T) {
 	}
 }
 
-type function struct {
-	Name      string         `json:"name"`
-	Arguments map[string]any `json:"arguments"`
-}
-
 func readFile(t *testing.T, base, name string) *bytes.Buffer {
 	t.Helper()
 
@@ -185,18 +180,18 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
 
 	calls := []api.ToolCall{
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "fahrenheit",
 					"location": "San Francisco, CA",
 				},
 			},
 		},
 		{
-			Function: function{
+			Function: api.ToolCallFunction{
 				Name: "get_current_weather",
-				Arguments: map[string]any{
+				Arguments: api.ToolCallFunctionArguments{
 					"format":   "celsius",
 					"location": "Toronto, Canada",
 				},
diff --git a/server/testdata/tools/command-r-plus.gotmpl b/server/testdata/tools/command-r-plus.gotmpl
index 088a4f0e..f30124e3 100644
--- a/server/testdata/tools/command-r-plus.gotmpl
+++ b/server/testdata/tools/command-r-plus.gotmpl
@@ -46,7 +46,7 @@ Action: ```json
 {{- range .ToolCalls }}
     {
         "tool_name": "{{ .Function.Name }}",
-        "parameters": {{ json .Function.Arguments }}
+        "parameters": {{ .Function.Arguments }}
     }
 {{- end }}
 ]```
diff --git a/server/testdata/tools/firefunction.gotmpl b/server/testdata/tools/firefunction.gotmpl
index bca88b3b..312be205 100644
--- a/server/testdata/tools/firefunction.gotmpl
+++ b/server/testdata/tools/firefunction.gotmpl
@@ -17,7 +17,7 @@ If you decide to call functions:
 
 Available functions as JSON spec:
 {{- if .Tools }}
-{{ json .Tools }}
+{{ .Tools }}
 {{- end }}<|eot_id|>
 {{- end }}
 {{- range .Messages }}<|start_header_id|>
@@ -25,7 +25,7 @@ Available functions as JSON spec:
 {{- end }}<|end_header_id|>
 {{- if .Content }}{{ .Content }}
 {{- else if .ToolCalls }} functools[
-{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}{{ "}" }}
+{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
 {{- end }}]
 {{- end }}<|eot_id|>
 {{- end }}<|start_header_id|>assistant<|end_header_id|>
\ No newline at end of file
diff --git a/server/testdata/tools/llama3-groq-tool-use.gotmpl b/server/testdata/tools/llama3-groq-tool-use.gotmpl
index e174f8a5..45e9b462 100644
--- a/server/testdata/tools/llama3-groq-tool-use.gotmpl
+++ b/server/testdata/tools/llama3-groq-tool-use.gotmpl
@@ -9,7 +9,7 @@
 
 Here are the available tools:
 
-{{- range .Tools }} {{ json .Function }}
+{{- range .Tools }} {{ .Function }}
 {{- end }} 
 {{- end }}
 {{- end }}<|eot_id|>
@@ -20,7 +20,7 @@ Here are the available tools:
 {{- else if eq .Role "assistant" }}
 {{- if .Content }}{{ .Content }}
 {{- else if .ToolCalls }}
-{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
 {{- end }}
 
 {{- end }}
diff --git a/server/testdata/tools/mistral.gotmpl b/server/testdata/tools/mistral.gotmpl
index a98bc7ad..b08d6c2c 100644
--- a/server/testdata/tools/mistral.gotmpl
+++ b/server/testdata/tools/mistral.gotmpl
@@ -1,13 +1,13 @@
 {{- range $index, $_ := .Messages }}
 {{- if eq .Role "user" }}
-{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
+{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
 {{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
 
 {{ end }}{{ .Content }}[/INST]
 {{- else if eq .Role "assistant" }}
 {{- if .Content }} {{ .Content }}
 {{- else if .ToolCalls }}[TOOL_CALLS] [
-{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
+{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
 {{- end }}]
 {{- end }}
 {{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
diff --git a/template/template.go b/template/template.go
index 85b4d21a..b5bfb16c 100644
--- a/template/template.go
+++ b/template/template.go
@@ -150,9 +150,9 @@ func (t *Template) Vars() []string {
 
 type Values struct {
 	Messages []api.Message
-	Tools    []api.Tool
-	Prompt   string
-	Suffix   string
+	api.Tools
+	Prompt string
+	Suffix string
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool

From 319fb1ce033921b70685abce0b2e06f52304fe74 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Thu, 18 Jul 2024 08:50:23 -0700
Subject: [PATCH 17/29] server: only parse tool calls if tools are provided
 (#5771)

* server: only parse tool calls if tools are provided

* still set `resp.Message.Content`
---
 server/routes.go | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/server/routes.go b/server/routes.go
index b4a8b4ac..347d5221 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1385,9 +1385,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 
 		resp.Message.Content = sb.String()
-		if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
-			resp.Message.ToolCalls = toolCalls
-			resp.Message.Content = ""
+
+		if len(req.Tools) > 0 {
+			if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
+				resp.Message.ToolCalls = toolCalls
+				resp.Message.Content = ""
+			}
 		}
 
 		c.JSON(http.StatusOK, resp)

From 84e5721f3a2febbd8bc2600bedcaa6e67aaa9b49 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Thu, 18 Jul 2024 11:28:19 -0700
Subject: [PATCH 18/29] always provide content even if empty (#5778)

---
 api/types.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/types.go b/api/types.go
index c7e9dce3..65a99c76 100644
--- a/api/types.go
+++ b/api/types.go
@@ -119,7 +119,7 @@ func (t Tools) String() string {
 // of images.
 type Message struct {
 	Role      string      `json:"role"`
-	Content   string      `json:"content,omitempty"`
+	Content   string      `json:"content"`
 	Images    []ImageData `json:"images,omitempty"`
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
 }

From 70b1010fa53095b4f8699b04375fa29f8c3e54b0 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Thu, 18 Jul 2024 11:44:57 -0700
Subject: [PATCH 19/29] server: check for empty tools array too (#5779)

---
 server/routes.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/routes.go b/server/routes.go
index 347d5221..c33b7195 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1290,7 +1290,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	caps := []Capability{CapabilityCompletion}
-	if req.Tools != nil {
+	if len(req.Tools) > 0 {
 		caps = append(caps, CapabilityTools)
 	}
 

From 43606d6d6a0d5e4105692d6d1233cfaa96a562b0 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Thu, 18 Jul 2024 12:07:59 -0700
Subject: [PATCH 20/29] fix parsing tool calls

---
 server/model.go | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/server/model.go b/server/model.go
index 65231ab1..a084dd8c 100644
--- a/server/model.go
+++ b/server/model.go
@@ -348,7 +348,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 	for offset := 0; offset < len(s); {
 		var obj map[string]any
 		decoder := json.NewDecoder(strings.NewReader(s[offset:]))
-		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) {
+		if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
 			break
 		} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
 			// skip over any syntax errors
@@ -357,6 +357,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
 			// skip over any unmarshalable types
 			offset += int(unmarshalType.Offset)
 		} else if err != nil {
+			slog.Error("parseToolCalls", "error", err)
 			return nil, false
 		} else {
 			offset += int(decoder.InputOffset())

From 51b2fd299cd568093ce796aef3e7e37ae656b02a Mon Sep 17 00:00:00 2001
From: royjhan <65097070+royjhan@users.noreply.github.com>
Date: Fri, 19 Jul 2024 11:19:20 -0700
Subject: [PATCH 21/29] adjust openai chat msg processing (#5729)

---
 openai/openai.go      | 7 +++----
 openai/openai_test.go | 8 ++++++--
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/openai/openai.go b/openai/openai.go
index 01864e48..93b63296 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -351,7 +351,6 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 		case string:
 			messages = append(messages, api.Message{Role: msg.Role, Content: content})
 		case []any:
-			message := api.Message{Role: msg.Role}
 			for _, c := range content {
 				data, ok := c.(map[string]any)
 				if !ok {
@@ -363,7 +362,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 					if !ok {
 						return nil, fmt.Errorf("invalid message format")
 					}
-					message.Content = text
+					messages = append(messages, api.Message{Role: msg.Role, Content: text})
 				case "image_url":
 					var url string
 					if urlMap, ok := data["image_url"].(map[string]any); ok {
@@ -395,12 +394,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 					if err != nil {
 						return nil, fmt.Errorf("invalid message format")
 					}
-					message.Images = append(message.Images, img)
+
+					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
 				default:
 					return nil, fmt.Errorf("invalid message format")
 				}
 			}
-			messages = append(messages, message)
 		default:
 			if msg.ToolCalls == nil {
 				return nil, fmt.Errorf("invalid message content type: %T", content)
diff --git a/openai/openai_test.go b/openai/openai_test.go
index 046ee69c..ad056e6d 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -161,8 +161,12 @@ func TestMiddlewareRequests(t *testing.T) {
 
 				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
 
-				if !bytes.Equal(chatReq.Messages[0].Images[0], img) {
-					t.Fatalf("expected image encoding, got %s", chatReq.Messages[0].Images[0])
+				if chatReq.Messages[1].Role != "user" {
+					t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role)
+				}
+
+				if !bytes.Equal(chatReq.Messages[1].Images[0], img) {
+					t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0])
 				}
 			},
 		},

From c57317cbf0c865dd1fbe4852e1cce3cf4703b7ee Mon Sep 17 00:00:00 2001
From: royjhan <65097070+royjhan@users.noreply.github.com>
Date: Fri, 19 Jul 2024 11:37:12 -0700
Subject: [PATCH 22/29] OpenAI: Function Based Testing (#5752)

* distinguish error forwarding

* more coverage

* rm comment
---
 openai/openai.go      |   1 +
 openai/openai_test.go | 461 +++++++++++++++++++++++++-----------------
 2 files changed, 279 insertions(+), 183 deletions(-)

diff --git a/openai/openai.go b/openai/openai.go
index 93b63296..de6f4bd5 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -877,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
 		chatReq, err := fromChatRequest(req)
 		if err != nil {
 			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
 		}
 
 		if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
diff --git a/openai/openai_test.go b/openai/openai_test.go
index ad056e6d..f978d46c 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -20,113 +20,59 @@ const prefix = `data:image/jpeg;base64,`
 const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
 const imageURL = prefix + image
 
-func TestMiddlewareRequests(t *testing.T) {
+func prepareRequest(req *http.Request, body any) {
+	bodyBytes, _ := json.Marshal(body)
+	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+	req.Header.Set("Content-Type", "application/json")
+}
+
+func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		bodyBytes, _ := io.ReadAll(c.Request.Body)
+		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+		err := json.Unmarshal(bodyBytes, capturedRequest)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
+		}
+		c.Next()
+	}
+}
+
+func TestChatMiddleware(t *testing.T) {
 	type testCase struct {
 		Name     string
-		Method   string
-		Path     string
-		Handler  func() gin.HandlerFunc
 		Setup    func(t *testing.T, req *http.Request)
-		Expected func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
 	}
 
-	var capturedRequest *http.Request
-
-	captureRequestMiddleware := func() gin.HandlerFunc {
-		return func(c *gin.Context) {
-			bodyBytes, _ := io.ReadAll(c.Request.Body)
-			c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-			capturedRequest = c.Request
-			c.Next()
-		}
-	}
+	var capturedRequest *api.ChatRequest
 
 	testCases := []testCase{
 		{
-			Name:    "chat handler",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
+			Name: "chat handler",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 					Model:    "test-model",
 					Messages: []Message{{Role: "user", Content: "Hello"}},
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
 				}
 
-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 				}
 
-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 				}
 			},
 		},
 		{
-			Name:    "completions handler",
-			Method:  http.MethodPost,
-			Path:    "/api/generate",
-			Handler: CompletionsMiddleware,
-			Setup: func(t *testing.T, req *http.Request) {
-				temp := float32(0.8)
-				body := CompletionRequest{
-					Model:       "test-model",
-					Prompt:      "Hello",
-					Temperature: &temp,
-					Stop:        []string{"\n", "stop"},
-					Suffix:      "suffix",
-				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var genReq api.GenerateRequest
-				if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
-					t.Fatal(err)
-				}
-
-				if genReq.Prompt != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
-				}
-
-				if genReq.Options["temperature"] != 1.6 {
-					t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
-				}
-
-				stopTokens, ok := genReq.Options["stop"].([]any)
-
-				if !ok {
-					t.Fatalf("expected stop tokens to be a list")
-				}
-
-				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
-					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
-				}
-
-				if genReq.Suffix != "suffix" {
-					t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
-				}
-			},
-		},
-		{
-			Name:    "chat handler with image content",
-			Method:  http.MethodPost,
-			Path:    "/api/chat",
-			Handler: ChatMiddleware,
+			Name: "chat handler with image content",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := ChatCompletionRequest{
 					Model: "test-model",
@@ -139,91 +85,254 @@ func TestMiddlewareRequests(t *testing.T) {
 						},
 					},
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var chatReq api.ChatRequest
-				if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusOK {
+					t.Fatalf("expected 200, got %d", resp.Code)
 				}
 
-				if chatReq.Messages[0].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
+				if req.Messages[0].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
 				}
 
-				if chatReq.Messages[0].Content != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
+				if req.Messages[0].Content != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
 				}
 
 				img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
 
-				if chatReq.Messages[1].Role != "user" {
-					t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role)
+				if req.Messages[1].Role != "user" {
+					t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
 				}
 
-				if !bytes.Equal(chatReq.Messages[1].Images[0], img) {
-					t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0])
+				if !bytes.Equal(req.Messages[1].Images[0], img) {
+					t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
 				}
 			},
 		},
 		{
-			Name:    "embed handler single input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
+			Name: "chat handler with tools",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model: "test-model",
+					Messages: []Message{
+						{Role: "user", Content: "What's the weather like in Paris Today?"},
+						{Role: "assistant", ToolCalls: []ToolCall{{
+							ID:   "id",
+							Type: "function",
+							Function: struct {
+								Name      string `json:"name"`
+								Arguments string `json:"arguments"`
+							}{
+								Name:      "get_current_weather",
+								Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
+							},
+						}}},
+					},
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != 200 {
+					t.Fatalf("expected 200, got %d", resp.Code)
+				}
+
+				if req.Messages[0].Content != "What's the weather like in Paris Today?" {
+					t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
+				}
+
+				if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
+					t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
+				}
+
+				if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
+					t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
+				}
+			},
+		},
+		{
+			Name: "chat handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := ChatCompletionRequest{
+					Model:    "test-model",
+					Messages: []Message{{Role: "user", Content: 2}},
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid message content type") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/chat", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
+
+			tc.Setup(t, req)
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
+		})
+	}
+}
+
+func TestCompletionsMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
+	}
+
+	var capturedRequest *api.GenerateRequest
+
+	testCases := []testCase{
+		{
+			Name: "completions handler",
+			Setup: func(t *testing.T, req *http.Request) {
+				temp := float32(0.8)
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: &temp,
+					Stop:        []string{"\n", "stop"},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if req.Prompt != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Prompt)
+				}
+
+				if req.Options["temperature"] != 1.6 {
+					t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
+				}
+
+				stopTokens, ok := req.Options["stop"].([]any)
+
+				if !ok {
+					t.Fatalf("expected stop tokens to be a list")
+				}
+
+				if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
+					t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
+				}
+
+				if req.Suffix != "suffix" {
+					t.Fatalf("expected 'suffix', got %s", req.Suffix)
+				}
+			},
+		},
+		{
+			Name: "completions handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := CompletionRequest{
+					Model:       "test-model",
+					Prompt:      "Hello",
+					Temperature: nil,
+					Stop:        []int{1, 2},
+					Suffix:      "suffix",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
+					t.Fatalf("error was not forwarded")
+				}
+			},
+		},
+	}
+
+	endpoint := func(c *gin.Context) {
+		c.Status(http.StatusOK)
+	}
+
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+	for _, tc := range testCases {
+		t.Run(tc.Name, func(t *testing.T) {
+			req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
+
+			tc.Setup(t, req)
+
+			resp := httptest.NewRecorder()
+			router.ServeHTTP(resp, req)
+
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
+		})
+	}
+}
+
+func TestEmbeddingsMiddleware(t *testing.T) {
+	type testCase struct {
+		Name     string
+		Setup    func(t *testing.T, req *http.Request)
+		Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
+	}
+
+	var capturedRequest *api.EmbedRequest
+
+	testCases := []testCase{
+		{
+			Name: "embed handler single input",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: "Hello",
 					Model: "test-model",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if req.Input != "Hello" {
+					t.Fatalf("expected 'Hello', got %s", req.Input)
 				}
 
-				if embedReq.Input != "Hello" {
-					t.Fatalf("expected 'Hello', got %s", embedReq.Input)
-				}
-
-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
 				}
 			},
 		},
 		{
-			Name:    "embed handler batch input",
-			Method:  http.MethodPost,
-			Path:    "/api/embed",
-			Handler: EmbeddingsMiddleware,
+			Name: "embed handler batch input",
 			Setup: func(t *testing.T, req *http.Request) {
 				body := EmbedRequest{
 					Input: []string{"Hello", "World"},
 					Model: "test-model",
 				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
+				prepareRequest(req, body)
 			},
-			Expected: func(t *testing.T, req *http.Request) {
-				var embedReq api.EmbedRequest
-				if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
-					t.Fatal(err)
-				}
-
-				input, ok := embedReq.Input.([]any)
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				input, ok := req.Input.([]any)
 
 				if !ok {
 					t.Fatalf("expected input to be a list")
@@ -237,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
 					t.Fatalf("expected 'World', got %s", input[1])
 				}
 
-				if embedReq.Model != "test-model" {
-					t.Fatalf("expected 'test-model', got %s", embedReq.Model)
+				if req.Model != "test-model" {
+					t.Fatalf("expected 'test-model', got %s", req.Model)
+				}
+			},
+		},
+		{
+			Name: "embed handler error forwarding",
+			Setup: func(t *testing.T, req *http.Request) {
+				body := EmbedRequest{
+					Model: "test-model",
+				}
+				prepareRequest(req, body)
+			},
+			Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
+				if resp.Code != http.StatusBadRequest {
+					t.Fatalf("expected 400, got %d", resp.Code)
+				}
+
+				if !strings.Contains(resp.Body.String(), "invalid input") {
+					t.Fatalf("error was not forwarded")
 				}
 			},
 		},
 	}
 
-	gin.SetMode(gin.TestMode)
-	router := gin.New()
-
 	endpoint := func(c *gin.Context) {
 		c.Status(http.StatusOK)
 	}
 
+	gin.SetMode(gin.TestMode)
+	router := gin.New()
+	router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
+	router.Handle(http.MethodPost, "/api/embed", endpoint)
+
 	for _, tc := range testCases {
 		t.Run(tc.Name, func(t *testing.T) {
-			router = gin.New()
-			router.Use(captureRequestMiddleware())
-			router.Use(tc.Handler())
-			router.Handle(tc.Method, tc.Path, endpoint)
-			req, _ := http.NewRequest(tc.Method, tc.Path, nil)
+			req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
 
-			if tc.Setup != nil {
-				tc.Setup(t, req)
-			}
+			tc.Setup(t, req)
 
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
-			tc.Expected(t, capturedRequest)
+			tc.Expected(t, capturedRequest, resp)
+
+			capturedRequest = nil
 		})
 	}
 }
@@ -284,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
 	}
 
 	testCases := []testCase{
-		{
-			Name:     "completions handler error forwarding",
-			Method:   http.MethodPost,
-			Path:     "/api/generate",
-			TestPath: "/api/generate",
-			Handler:  CompletionsMiddleware,
-			Endpoint: func(c *gin.Context) {
-				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
-			},
-			Setup: func(t *testing.T, req *http.Request) {
-				body := CompletionRequest{
-					Model:  "test-model",
-					Prompt: "Hello",
-				}
-
-				bodyBytes, _ := json.Marshal(body)
-
-				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
-				req.Header.Set("Content-Type", "application/json")
-			},
-			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				if resp.Code != http.StatusBadRequest {
-					t.Fatalf("expected 400, got %d", resp.Code)
-				}
-
-				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
-					t.Fatalf("error was not forwarded")
-				}
-			},
-		},
 		{
 			Name:     "list handler",
 			Method:   http.MethodGet,
@@ -330,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
 				})
 			},
 			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
-				assert.Equal(t, http.StatusOK, resp.Code)
-
 				var listResp ListCompletion
 				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
 					t.Fatal(err)
@@ -395,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
 			resp := httptest.NewRecorder()
 			router.ServeHTTP(resp, req)
 
+			assert.Equal(t, http.StatusOK, resp.Code)
+
 			tc.Expected(t, resp)
 		})
 	}

From e8b954c646544d40d84be50aae9cd909fcbd8f41 Mon Sep 17 00:00:00 2001
From: Josh <76125168+joshyan1@users.noreply.github.com>
Date: Fri, 19 Jul 2024 15:24:29 -0700
Subject: [PATCH 23/29] server: validate template (#5734)

add template validation to modelfile
---
 server/images.go             |  6 ++++++
 server/routes.go             | 14 +++++++++++---
 server/routes_create_test.go | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/server/images.go b/server/images.go
index 5e4e8858..574dec19 100644
--- a/server/images.go
+++ b/server/images.go
@@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				layers = append(layers, baseLayer.Layer)
 			}
 		case "license", "template", "system":
+			if c.Name == "template" {
+				if _, err := template.Parse(c.Args); err != nil {
+					return fmt.Errorf("%w: %s", errBadTemplate, err)
+				}
+			}
+
 			if c.Name != "license" {
 				// replace
 				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
diff --git a/server/routes.go b/server/routes.go
index c33b7195..85db7924 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -56,6 +56,7 @@ func init() {
 }
 
 var errRequired = errors.New("is required")
+var errBadTemplate = errors.New("template error")
 
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
@@ -609,8 +610,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 
 		quantization := cmp.Or(r.Quantize, r.Quantization)
 		if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
+			if errors.Is(err, errBadTemplate) {
+			  ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
+			}
 			ch <- gin.H{"error": err.Error()}
-		}
+		  }
 	}()
 
 	if r.Stream != nil && !*r.Stream {
@@ -1196,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
 				return
 			}
 		case gin.H:
+			status, ok := r["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
 			if errorMsg, ok := r["error"].(string); ok {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
+				c.JSON(status, gin.H{"error": errorMsg})
 				return
 			} else {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
+				c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
 				return
 			}
 		default:
diff --git a/server/routes_create_test.go b/server/routes_create_test.go
index cb548ebd..3234ea5e 100644
--- a/server/routes_create_test.go
+++ b/server/routes_create_test.go
@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
 	if string(system) != "Say bye!" {
 		t.Errorf("expected \"Say bye!\", actual %s", system)
 	}
+
+	t.Run("incomplete template", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+	
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with unclosed if", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+	
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with undefined function", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{  Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+	
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
 }
 
 func TestCreateLicenses(t *testing.T) {

From 69a2d4ccffa7532680bed245fad77bb166ec0bb9 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Fri, 19 Jul 2024 19:11:25 -0700
Subject: [PATCH 24/29] Fix generate test flakyness (#5804)

---
 server/routes_generate_test.go | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index c914b300..5c0caff1 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -73,8 +73,8 @@ func TestGenerateChat(t *testing.T) {
 			getCpuFn:      gpu.GetCPUInfo,
 			reschedDelay:  250 * time.Millisecond,
 			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
-				// add 10ms delay to simulate loading
-				time.Sleep(10 * time.Millisecond)
+				// add small delay to simulate loading
+				time.Sleep(time.Millisecond)
 				req.successCh <- &runnerRef{
 					llama: &mock,
 				}
@@ -371,6 +371,8 @@ func TestGenerate(t *testing.T) {
 			getCpuFn:      gpu.GetCPUInfo,
 			reschedDelay:  250 * time.Millisecond,
 			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
+				// add small delay to simulate loading
+				time.Sleep(time.Millisecond)
 				req.successCh <- &runnerRef{
 					llama: &mock,
 				}

From 20090f3172c4848584060cbd51e7c9b14c3630cb Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Fri, 19 Jul 2024 20:19:26 -0700
Subject: [PATCH 25/29] preserve last assistant message (#5802)

---
 template/template.go      |  3 ++-
 template/template_test.go | 20 ++++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/template/template.go b/template/template.go
index b5bfb16c..f7453791 100644
--- a/template/template.go
+++ b/template/template.go
@@ -264,6 +264,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
 		if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
 			cut = true
+			return false
 		}
 
 		return cut
@@ -273,7 +274,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
 		"System":   system,
 		"Prompt":   prompt,
-		"Response": "",
+		"Response": response,
 	}); err != nil {
 		return err
 	}
diff --git a/template/template_test.go b/template/template_test.go
index ae0db80b..b46e1df5 100644
--- a/template/template_test.go
+++ b/template/template_test.go
@@ -260,6 +260,26 @@ func TestExecuteWithMessages(t *testing.T) {
 
 Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
 		},
+		{
+			"mistral assistant",
+			[]template{
+				{"no response", `[INST] {{ .Prompt }}[/INST] `},
+				{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
+				{"messages", `
+{{- range $i, $m := .Messages }}
+{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
+{{- end }}`},
+			},
+			Values{
+				Messages: []api.Message{
+					{Role: "user", Content: "Hello friend!"},
+					{Role: "assistant", Content: "Hello human!"},
+					{Role: "user", Content: "What is your name?"},
+					{Role: "assistant", Content: "My name is Ollama and I"},
+				},
+			},
+			`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
+		},
 		{
 			"chatml",
 			[]template{

From 1475eab95f5a4ddc5b1bb169df0d89e71732dfa0 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Sat, 20 Jul 2024 13:41:21 -0400
Subject: [PATCH 26/29] add patch for tekken (#5807)

---
 llm/patches/10-tekken.diff | 43 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 llm/patches/10-tekken.diff

diff --git a/llm/patches/10-tekken.diff b/llm/patches/10-tekken.diff
new file mode 100644
index 00000000..56a583e0
--- /dev/null
+++ b/llm/patches/10-tekken.diff
@@ -0,0 +1,43 @@
+diff --git a/include/llama.h b/include/llama.h
+index bb4b05ba..a92174e0 100644
+--- a/include/llama.h
++++ b/include/llama.h
+@@ -92,6 +92,7 @@ extern "C" {
+         LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+         LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+         LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
++        LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
+     };
+ 
+     // note: these values should be synchronized with ggml_rope
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 18364976..435b6fe5 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -5429,6 +5429,12 @@ static void llm_load_vocab(
+             } else if (
+                 tokenizer_pre == "jais") {
+                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
++            } else if (
++                tokenizer_pre == "tekken") {
++                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
++                vocab.tokenizer_clean_spaces = false;
++                vocab.tokenizer_ignore_merges = true;
++                vocab.tokenizer_add_bos = true;
+             } else {
+                 LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+@@ -15448,6 +15454,13 @@ struct llm_tokenizer_bpe {
+                     " ?[^(\\s|.,!?…。,、।۔،)]+",
+                 };
+                 break;
++            case LLAMA_VOCAB_PRE_TYPE_TEKKEN:
++                    // original regex from tokenizer.json
++                    // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
++                regex_exprs = {
++                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
++                };
++                break;
+             default:
+                 // default regex for BPE tokenization pre-processing
+                 regex_exprs = {

From 283948c83b5cbf74f6cf86dce4434238e64d6e1c Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen 
Date: Fri, 19 Jul 2024 15:07:26 -0700
Subject: [PATCH 27/29] Adjust windows ROCm discovery

The v5 hip library returns unsupported GPUs which wont enumerate at
inference time in the runner so this makes sure we align discovery.  The
gfx906 cards are no longer supported so we shouldn't compile with that
GPU type as it wont enumerate at runtime.
---
 docs/gpu.md                  | 15 +++++++++++++--
 gpu/amd_hip_windows.go       |  5 +++--
 gpu/amd_windows.go           |  3 ++-
 llm/generate/gen_windows.ps1 |  2 +-
 llm/server.go                |  2 ++
 5 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/docs/gpu.md b/docs/gpu.md
index 80f276c3..e669ea32 100644
--- a/docs/gpu.md
+++ b/docs/gpu.md
@@ -46,13 +46,24 @@ sudo modprobe nvidia_uvm`
 
 ## AMD Radeon
 Ollama supports the following AMD GPUs:
+
+### Linux Support
 | Family         | Cards and accelerators                                                                                                               |
 | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
 | AMD Radeon RX  | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56`    |
 | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
 | AMD Instinct   | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50`                                                               |
 
-### Overrides
+### Windows Support
+With ROCm v6.1, the following GPUs are supported on Windows.
+
+| Family         | Cards and accelerators                                                                                                               |
+| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
+| AMD Radeon RX  | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800`    |
+| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
+
+
+### Overrides on Linux
 Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
 some cases you can force the system to try to use a similar LLVM target that is
 close.  For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
@@ -63,7 +74,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
 server.  If you have an unsupported AMD GPU you can experiment using the list of
 supported types below.
 
-At this time, the known supported GPU types are the following LLVM Targets.
+At this time, the known supported GPU types on linux are the following LLVM Targets.
 This table shows some example GPUs that map to these LLVM targets:
 | **LLVM Target** | **An Example GPU** |
 |-----------------|---------------------|
diff --git a/gpu/amd_hip_windows.go b/gpu/amd_hip_windows.go
index 2586278c..98806234 100644
--- a/gpu/amd_hip_windows.go
+++ b/gpu/amd_hip_windows.go
@@ -33,9 +33,10 @@ type HipLib struct {
 }
 
 func NewHipLib() (*HipLib, error) {
-	h, err := windows.LoadLibrary("amdhip64.dll")
+	// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs
+	h, err := windows.LoadLibrary("amdhip64_6.dll")
 	if err != nil {
-		return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
+		return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err)
 	}
 	hl := &HipLib{}
 	hl.dll = h
diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go
index 425259d7..20aed447 100644
--- a/gpu/amd_windows.go
+++ b/gpu/amd_windows.go
@@ -92,7 +92,8 @@ func AMDGetGPUInfo() []RocmGPUInfo {
 			continue
 		}
 		if gfxOverride == "" {
-			if !slices.Contains[[]string, string](supported, gfx) {
+			// Strip off Target Features when comparing
+			if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
 				slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
 				// TODO - consider discrete markdown just for ROCM troubleshooting?
 				slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1
index beb964f9..d8bce92d 100644
--- a/llm/generate/gen_windows.ps1
+++ b/llm/generate/gen_windows.ps1
@@ -7,8 +7,8 @@ function amdGPUs {
         return $env:AMDGPU_TARGETS
     }
     # Current supported rocblas list from ROCm v6.1.2 on windows
+    # https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
     $GPU_LIST = @(
-        "gfx906:xnack-"
         "gfx1030"
         "gfx1100"
         "gfx1101"
diff --git a/llm/server.go b/llm/server.go
index 36c0e0b5..ba7eab03 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -385,8 +385,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 			filteredEnv := []string{}
 			for _, ev := range s.cmd.Env {
 				if strings.HasPrefix(ev, "CUDA_") ||
+					strings.HasPrefix(ev, "ROCR_") ||
 					strings.HasPrefix(ev, "ROCM_") ||
 					strings.HasPrefix(ev, "HIP_") ||
+					strings.HasPrefix(ev, "GPU_") ||
 					strings.HasPrefix(ev, "HSA_") ||
 					strings.HasPrefix(ev, "GGML_") ||
 					strings.HasPrefix(ev, "PATH=") ||

From 5534f2cc6a3f29022998950472741d16e7a66b40 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Sat, 20 Jul 2024 21:48:12 -0400
Subject: [PATCH 28/29] llm: consider `head_dim` in llama arch (#5817)

---
 llm/patches/11-embd_kv.diff | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)
 create mode 100644 llm/patches/11-embd_kv.diff

diff --git a/llm/patches/11-embd_kv.diff b/llm/patches/11-embd_kv.diff
new file mode 100644
index 00000000..ad17a700
--- /dev/null
+++ b/llm/patches/11-embd_kv.diff
@@ -0,0 +1,19 @@
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 2b9ace28..e60d3d8d 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -6052,10 +6052,10 @@ static bool llm_load_tensors(
+ 
+                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+ 
+-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
++                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd,  n_embd_head_k * n_head});
++                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
++                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
++                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+ 
+                         // optional bias tensors
+                         layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);

From 80ee9b5e47fc0ea99d1f3f33224923266627c15c Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Sun, 21 Jul 2024 00:22:11 -0400
Subject: [PATCH 29/29] Remove out of space test temporarily (#5825)

---
 server/sched_test.go | 37 -------------------------------------
 1 file changed, 37 deletions(-)

diff --git a/server/sched_test.go b/server/sched_test.go
index 7991e7c5..9ddd1fab 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -7,7 +7,6 @@ import (
 	"fmt"
 	"log/slog"
 	"os"
-	"runtime"
 	"testing"
 	"time"
 
@@ -356,42 +355,6 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
 	s.loadedMu.Unlock()
 }
 
-func TestRequestsModelTooBigForSystem(t *testing.T) {
-	ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
-	defer done()
-	s := InitScheduler(ctx)
-	s.getGpuFn = func() gpu.GpuInfoList {
-		g := gpu.GpuInfo{Library: "metal"}
-		g.TotalMemory = 4 * format.MebiByte
-		g.FreeMemory = 3 * format.MebiByte
-		return []gpu.GpuInfo{g}
-	}
-
-	s.getCpuFn = func() gpu.GpuInfoList {
-		g := gpu.GpuInfo{Library: "cpu"}
-		g.TotalMemory = 4 * format.MebiByte
-		g.FreeMemory = 2 * format.MebiByte
-		return []gpu.GpuInfo{g}
-	}
-	a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
-
-	s.newServerFn = a.newServer
-	slog.Info("a")
-	s.pendingReqCh <- a.req
-	require.Len(t, s.pendingReqCh, 1)
-	s.Run(ctx)
-	select {
-	case <-a.req.successCh:
-		if runtime.GOOS == "linux" {
-			t.Fatal("request should have been rejected with out of space")
-		}
-		// else - Darwin and Windows don't reject right now
-	case err := <-a.req.errCh:
-		require.Contains(t, err.Error(), "too large")
-	case <-ctx.Done():
-		t.Fatal("timeout")
-	}
-}
 func TestGetRunner(t *testing.T) {
 	ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
 	defer done()