diff --git a/server/routes_test.go b/server/routes_test.go index 7f0984ea..a8ccb8db 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -1,10 +1,12 @@ package server import ( + "bytes" "context" "encoding/json" "fmt" "io" + "net/http" "net/http/httptest" "os" @@ -31,6 +33,20 @@ func Test_Routes(t *testing.T) { Setup func(t *testing.T, req *http.Request) Expected func(t *testing.T, resp *http.Response) } + var tempModelFile string + + createTestModel := func(t *testing.T, name string) { + f, err := os.CreateTemp("", "ollama-model") + assert.Nil(t, err) + defer os.RemoveAll(f.Name()) + + modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name())) + commands, err := parser.Parse(modelfile) + assert.Nil(t, err) + fn := func(resp api.ProgressResponse) {} + err = CreateModel(context.TODO(), name, "", commands, fn) + assert.Nil(t, err) + } testCases := []testCase{ { @@ -70,16 +86,7 @@ func Test_Routes(t *testing.T) { Method: http.MethodGet, Path: "/api/tags", Setup: func(t *testing.T, req *http.Request) { - f, err := os.CreateTemp("", "ollama-modelfile") - assert.Nil(t, err) - defer os.RemoveAll(f.Name()) - - modelfile := strings.NewReader(fmt.Sprintf("FROM %s", f.Name())) - commands, err := parser.Parse(modelfile) - assert.Nil(t, err) - fn := func(resp api.ProgressResponse) {} - err = CreateModel(context.TODO(), "test", "", commands, fn) - assert.Nil(t, err) + createTestModel(t, "test-model") }, Expected: func(t *testing.T, resp *http.Response) { contentType := resp.Header.Get("Content-Type") @@ -88,12 +95,66 @@ func Test_Routes(t *testing.T) { assert.Nil(t, err) var modelList api.ListResponse - err = json.Unmarshal(body, &modelList) assert.Nil(t, err) assert.Equal(t, 1, len(modelList.Models)) - assert.Equal(t, modelList.Models[0].Name, "test:latest") + assert.Equal(t, modelList.Models[0].Name, "test-model:latest") + }, + }, + { + Name: "Create Model Handler", + Method: http.MethodPost, + Path: "/api/create", + Setup: func(t *testing.T, req *http.Request) { + f, err := os.CreateTemp("", "ollama-model") + assert.Nil(t, err) + tempModelFile = f.Name() + + stream := false + createReq := api.CreateRequest{ + Name: "t-bone", + Modelfile: fmt.Sprintf("FROM %s", f.Name()), + Stream: &stream, + } + jsonData, err := json.Marshal(createReq) + assert.Nil(t, err) + + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + os.RemoveAll(tempModelFile) + + contentType := resp.Header.Get("Content-Type") + assert.Equal(t, "application/json", contentType) + _, err := io.ReadAll(resp.Body) + assert.Nil(t, err) + assert.Equal(t, resp.StatusCode, 200) + + model, err := GetModel("t-bone") + assert.Nil(t, err) + assert.Equal(t, "t-bone:latest", model.ShortName) + }, + }, + { + Name: "Copy Model Handler", + Method: http.MethodPost, + Path: "/api/copy", + Setup: func(t *testing.T, req *http.Request) { + createTestModel(t, "hamshank") + copyReq := api.CopyRequest{ + Source: "hamshank", + Destination: "beefsteak", + } + jsonData, err := json.Marshal(copyReq) + assert.Nil(t, err) + + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + model, err := GetModel("beefsteak") + assert.Nil(t, err) + assert.Equal(t, "beefsteak:latest", model.ShortName) }, }, } @@ -121,11 +182,13 @@ func Test_Routes(t *testing.T) { } resp, err := httpSrv.Client().Do(req) + defer resp.Body.Close() assert.Nil(t, err) if tc.Expected != nil { tc.Expected(t, resp) } + } }