From 3948c6ea06a3cdcc331e0a90bc0c754f62c8be55 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 18 Dec 2023 10:41:02 -0800 Subject: [PATCH] add magic header for unit tests (#1558) --- server/routes_test.go | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/server/routes_test.go b/server/routes_test.go index a8ccb8db..7f3e4f05 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "net/http" "net/http/httptest" "os" @@ -33,17 +32,29 @@ func Test_Routes(t *testing.T) { Setup func(t *testing.T, req *http.Request) Expected func(t *testing.T, resp *http.Response) } - var tempModelFile string + + createTestFile := func(t *testing.T, name string) string { + f, err := os.CreateTemp(t.TempDir(), name) + assert.Nil(t, err) + defer f.Close() + + _, err = f.Write([]byte("GGUF")) + assert.Nil(t, err) + _, err = f.Write([]byte{0x2, 0}) + assert.Nil(t, err) + + return f.Name() + } createTestModel := func(t *testing.T, name string) { - f, err := os.CreateTemp("", "ollama-model") - assert.Nil(t, err) - defer os.RemoveAll(f.Name()) + fname := createTestFile(t, "ollama-model") - modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name())) + modelfile := strings.NewReader(fmt.Sprintf("FROM %s", fname)) commands, err := parser.Parse(modelfile) assert.Nil(t, err) - fn := func(resp api.ProgressResponse) {} + fn := func(resp api.ProgressResponse) { + t.Logf("Status: %s", resp.Status) + } err = CreateModel(context.TODO(), name, "", commands, fn) assert.Nil(t, err) } @@ -107,9 +118,9 @@ func Test_Routes(t *testing.T) { Method: http.MethodPost, Path: "/api/create", Setup: func(t *testing.T, req *http.Request) { - f, err := os.CreateTemp("", "ollama-model") + f, err := os.CreateTemp(t.TempDir(), "ollama-model") assert.Nil(t, err) - tempModelFile = f.Name() + defer f.Close() stream := false createReq := api.CreateRequest{ @@ -123,8 +134,6 @@ func Test_Routes(t *testing.T) { req.Body = io.NopCloser(bytes.NewReader(jsonData)) }, Expected: func(t *testing.T, resp *http.Response) { - os.RemoveAll(tempModelFile) - contentType := resp.Header.Get("Content-Type") assert.Equal(t, "application/json", contentType) _, err := io.ReadAll(resp.Body) @@ -173,6 +182,7 @@ func Test_Routes(t *testing.T) { os.Setenv("OLLAMA_MODELS", workDir) for _, tc := range testCases { + t.Logf("Running Test: [%s]", tc.Name) u := httpSrv.URL + tc.Path req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) assert.Nil(t, err)