From a2fc933fed2e05266aff324deb2d35933563a575 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 17 Apr 2024 17:23:19 -0700 Subject: [PATCH 1/6] update delete handler to use model.Name --- server/layer.go | 23 ++++++++++++++ server/manifest.go | 76 ++++++++++++++++++++++++++++++++++++++++------ server/routes.go | 35 ++++++--------------- 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/server/layer.go b/server/layer.go index dcca3854..d3d3c120 100644 --- a/server/layer.go +++ b/server/layer.go @@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadCloser, error) { return os.Open(blob) } + +func (l *Layer) Remove() error { + ms, err := Manifests() + if err != nil { + return err + } + + for _, m := range ms { + for _, layer := range append(m.Layers, m.Config) { + if layer.Digest == l.Digest { + // something is using this layer + return nil + } + } + } + + blob, err := GetBlobsPath(l.Digest) + if err != nil { + return err + } + + return os.Remove(blob) +} diff --git a/server/manifest.go b/server/manifest.go index 8a17700e..36ed5b4c 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -14,7 +14,9 @@ import ( type Manifest struct { ManifestV2 - Digest string `json:"-"` + + filepath string + digest string } func (m *Manifest) Size() (size int64) { @@ -25,9 +27,28 @@ func (m *Manifest) Size() (size int64) { return } -func ParseNamedManifest(name model.Name) (*Manifest, error) { - if !name.IsFullyQualified() { - return nil, model.Unqualified(name) +func (m *Manifest) Remove() error { + if err := os.Remove(m.filepath); err != nil { + return err + } + + for _, layer := range append(m.Layers, m.Config) { + if err := layer.Remove(); err != nil { + return err + } + } + + manifests, err := GetManifestPath() + if err != nil { + return err + } + + return PruneDirectory(manifests) +} + +func ParseNamedManifest(n model.Name) (*Manifest, error) { + if !n.IsFullyQualified() { + return nil, model.Unqualified(n) } manifests, err := GetManifestPath() @@ -35,20 +56,24 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) { return nil, err } - var manifest ManifestV2 - manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath())) + p := filepath.Join(manifests, n.Filepath()) + + var m ManifestV2 + f, err := os.Open(p) if err != nil { return nil, err } + defer f.Close() sha256sum := sha256.New() - if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil { + if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil { return nil, err } return &Manifest{ - ManifestV2: manifest, - Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), + ManifestV2: m, + filepath: p, + digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), }, nil } @@ -77,3 +102,36 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error { return os.WriteFile(manifestPath, b.Bytes(), 0o644) } + +func Manifests() (map[model.Name]*Manifest, error) { + manifests, err := GetManifestPath() + if err != nil { + return nil, err + } + + // TODO(mxyng): use something less brittle + matches, err := filepath.Glob(fmt.Sprintf("%s/*/*/*/*", manifests)) + if err != nil { + return nil, err + } + + ms := make(map[model.Name]*Manifest) + for _, match := range matches { + rel, err := filepath.Rel(manifests, match) + if err != nil { + return nil, err + } + + n := model.ParseNameFromFilepath(rel) + if n.IsValid() { + m, err := ParseNamedManifest(n) + if err != nil { + return nil, err + } + + ms[n] = m + } + } + + return ms, nil +} diff --git a/server/routes.go b/server/routes.go index 123ef9a3..ff888e3c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -574,48 +574,31 @@ func (s *Server) CreateModelHandler(c *gin.Context) { } func (s *Server) DeleteModelHandler(c *gin.Context) { - var req api.DeleteRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.DeleteRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + n := model.ParseName(cmp.Or(r.Model, r.Name)) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } - if err := DeleteModel(model); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)}) - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return - } - - manifestsPath, err := GetManifestPath() + m, err := ParseNamedManifest(n) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - if err := PruneDirectory(manifestsPath); err != nil { + if err := m.Remove(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - - c.JSON(http.StatusOK, nil) } func (s *Server) ShowModelHandler(c *gin.Context) { @@ -769,7 +752,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { Model: n.DisplayShortest(), Name: n.DisplayShortest(), Size: m.Size(), - Digest: m.Digest, + Digest: m.digest, ModifiedAt: info.ModTime(), Details: api.ModelDetails{ Format: c.ModelFormat, From c2714fcbfd600c2a13efbc42bab95b49b0b4fa33 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 6 May 2024 16:34:13 -0700 Subject: [PATCH 2/6] routes: use Manifests for ListHandler --- server/manifest.go | 11 ++++- server/manifest_test.go | 90 +++++++++++++++++++++++++++++++++++++++++ server/routes.go | 84 +++++++++++++------------------------- 3 files changed, 127 insertions(+), 58 deletions(-) create mode 100644 server/manifest_test.go diff --git a/server/manifest.go b/server/manifest.go index 36ed5b4c..131d4918 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "os" "path/filepath" @@ -16,6 +17,7 @@ type Manifest struct { ManifestV2 filepath string + fi os.FileInfo digest string } @@ -65,6 +67,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { } defer f.Close() + fi, err := f.Stat() + if err != nil { + return nil, err + } + sha256sum := sha256.New() if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil { return nil, err @@ -73,6 +80,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { return &Manifest{ ManifestV2: m, filepath: p, + fi: fi, digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), }, nil } @@ -126,7 +134,8 @@ func Manifests() (map[model.Name]*Manifest, error) { if n.IsValid() { m, err := ParseNamedManifest(n) if err != nil { - return nil, err + slog.Warn("bad manifest", "name", n, "error", err) + continue } ms[n] = m diff --git a/server/manifest_test.go b/server/manifest_test.go new file mode 100644 index 00000000..35c6bc8d --- /dev/null +++ b/server/manifest_test.go @@ -0,0 +1,90 @@ +package server + +import ( + "encoding/json" + "os" + "path/filepath" + "slices" + "testing" + + "github.com/ollama/ollama/types/model" +) + +func createManifest(t *testing.T, path, name string) { + t.Helper() + + p := filepath.Join(path, "manifests", name) + if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil { + t.Fatal(err) + } + + f, err := os.Create(p) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil { + t.Fatal(err) + } +} + +func TestManifests(t *testing.T) { + cases := map[string][]string{ + "empty": {}, + "single": { + filepath.Join("host", "namespace", "model", "tag"), + }, + "multiple": { + filepath.Join("registry.ollama.ai", "library", "llama3", "latest"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"), + }, + "hidden": { + filepath.Join("host", "namespace", "model", "tag"), + filepath.Join("host", "namespace", "model", ".hidden"), + }, + } + + for n, wants := range cases { + t.Run(n, func(t *testing.T) { + d := t.TempDir() + t.Setenv("OLLAMA_MODELS", d) + + for _, want := range wants { + createManifest(t, d, want) + } + + ms, err := Manifests() + if err != nil { + t.Fatal(err) + } + + var ns []model.Name + for k := range ms { + ns = append(ns, k) + } + + for _, want := range wants { + n := model.ParseNameFromFilepath(want) + if !n.IsValid() && slices.Contains(ns, n) { + t.Errorf("unexpected invalid name: %s", want) + } else if n.IsValid() && !slices.Contains(ns, n) { + t.Errorf("missing valid name: %s", want) + } + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index ff888e3c..14853feb 100644 --- a/server/routes.go +++ b/server/routes.go @@ -702,72 +702,42 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } func (s *Server) ListModelsHandler(c *gin.Context) { - manifests, err := GetManifestPath() + ms, err := Manifests() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } models := []api.ModelResponse{} - if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error { - if !info.IsDir() { - rel, err := filepath.Rel(manifests, path) - if err != nil { - return err - } + for n, m := range ms { + f, err := m.Config.Open() + if err != nil { + slog.Warn("bad manifest filepath", "name", n, "error", err) + continue + } + defer f.Close() - if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil { - return err - } else if hidden { - return nil - } - - n := model.ParseNameFromFilepath(rel) - if !n.IsValid() { - slog.Warn("bad manifest filepath", "path", rel) - return nil - } - - m, err := ParseNamedManifest(n) - if err != nil { - slog.Warn("bad manifest", "name", n, "error", err) - return nil - } - - f, err := m.Config.Open() - if err != nil { - slog.Warn("bad manifest config filepath", "name", n, "error", err) - return nil - } - defer f.Close() - - var c ConfigV2 - if err := json.NewDecoder(f).Decode(&c); err != nil { - slog.Warn("bad manifest config", "name", n, "error", err) - return nil - } - - // tag should never be masked - models = append(models, api.ModelResponse{ - Model: n.DisplayShortest(), - Name: n.DisplayShortest(), - Size: m.Size(), - Digest: m.digest, - ModifiedAt: info.ModTime(), - Details: api.ModelDetails{ - Format: c.ModelFormat, - Family: c.ModelFamily, - Families: c.ModelFamilies, - ParameterSize: c.ModelType, - QuantizationLevel: c.FileType, - }, - }) + var cf ConfigV2 + if err := json.NewDecoder(f).Decode(&cf); err != nil { + slog.Warn("bad manifest config", "name", n, "error", err) + continue } - return nil - }); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + // tag should never be masked + models = append(models, api.ModelResponse{ + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + Size: m.Size(), + Digest: m.digest, + ModifiedAt: m.fi.ModTime(), + Details: api.ModelDetails{ + Format: cf.ModelFormat, + Family: cf.ModelFamily, + Families: cf.ModelFamilies, + ParameterSize: cf.ModelType, + QuantizationLevel: cf.FileType, + }, + }) } slices.SortStableFunc(models, func(i, j api.ModelResponse) int { From b8772a353f83839fd299184766827db86454701e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 8 May 2024 14:54:52 -0700 Subject: [PATCH 3/6] remove DeleteModel --- server/images.go | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/server/images.go b/server/images.go index 3f415b6d..94057a49 100644 --- a/server/images.go +++ b/server/images.go @@ -746,37 +746,6 @@ func PruneDirectory(path string) error { return nil } -func DeleteModel(name string) error { - mp := ParseModelPath(name) - manifest, _, err := GetManifest(mp) - if err != nil { - return err - } - - deleteMap := make(map[string]struct{}) - for _, layer := range manifest.Layers { - deleteMap[layer.Digest] = struct{}{} - } - deleteMap[manifest.Config.Digest] = struct{}{} - - err = deleteUnusedLayers(&mp, deleteMap) - if err != nil { - return err - } - - fp, err := mp.GetManifestPath() - if err != nil { - return err - } - err = os.Remove(fp) - if err != nil { - slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err)) - return err - } - - return nil -} - func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) From a385382ff5e1dc94ed17e8cd0b29f031e91c33ed Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 8 May 2024 15:56:40 -0700 Subject: [PATCH 4/6] filepath.Join --- server/manifest.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/manifest.go b/server/manifest.go index 131d4918..41a96c55 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -118,7 +118,7 @@ func Manifests() (map[model.Name]*Manifest, error) { } // TODO(mxyng): use something less brittle - matches, err := filepath.Glob(fmt.Sprintf("%s/*/*/*/*", manifests)) + matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*")) if err != nil { return nil, err } From 81fb06f5307349244263a199d66eb30926a71d28 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 9 May 2024 10:00:18 -0700 Subject: [PATCH 5/6] more resilient Manifests --- server/manifest.go | 17 ++++++++++++++--- server/manifest_test.go | 4 ++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/server/manifest.go b/server/manifest.go index 41a96c55..a5251298 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -125,13 +125,24 @@ func Manifests() (map[model.Name]*Manifest, error) { ms := make(map[model.Name]*Manifest) for _, match := range matches { - rel, err := filepath.Rel(manifests, match) + fi, err := os.Stat(match) if err != nil { return nil, err } - n := model.ParseNameFromFilepath(rel) - if n.IsValid() { + if !fi.IsDir() { + rel, err := filepath.Rel(manifests, match) + if err != nil { + slog.Warn("bad filepath", "path", match, "error", err) + continue + } + + n := model.ParseNameFromFilepath(rel) + if !n.IsValid() { + slog.Warn("bad manifest name", "path", rel, "error", err) + continue + } + m, err := ParseNamedManifest(n) if err != nil { slog.Warn("bad manifest", "name", n, "error", err) diff --git a/server/manifest_test.go b/server/manifest_test.go index 35c6bc8d..4da86745 100644 --- a/server/manifest_test.go +++ b/server/manifest_test.go @@ -56,6 +56,10 @@ func TestManifests(t *testing.T) { filepath.Join("host", "namespace", "model", "tag"), filepath.Join("host", "namespace", "model", ".hidden"), }, + "subdir": { + filepath.Join("host", "namespace", "model", "tag", "one"), + filepath.Join("host", "namespace", "model", "tag", "another", "one"), + }, } for n, wants := range cases { From c5e892cb3ef21b4ba315389210205b65e46b62aa Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 13 May 2024 14:41:37 -0700 Subject: [PATCH 6/6] update tests --- server/manifest_test.go | 114 ++++++++++++++++++------- server/routes_create_test.go | 160 +++++++++++++++++++++++++++++++++++ server/routes_delete_test.go | 71 ++++++++++++++++ server/routes_list_test.go | 61 +++++++++++++ 4 files changed, 377 insertions(+), 29 deletions(-) create mode 100644 server/routes_create_test.go create mode 100644 server/routes_delete_test.go create mode 100644 server/routes_list_test.go diff --git a/server/manifest_test.go b/server/manifest_test.go index 4da86745..b85976fd 100644 --- a/server/manifest_test.go +++ b/server/manifest_test.go @@ -30,35 +30,76 @@ func createManifest(t *testing.T, path, name string) { } func TestManifests(t *testing.T) { - cases := map[string][]string{ + cases := map[string]struct { + ps []string + wantValidCount int + wantInvalidCount int + }{ "empty": {}, "single": { - filepath.Join("host", "namespace", "model", "tag"), + ps: []string{ + filepath.Join("host", "namespace", "model", "tag"), + }, + wantValidCount: 1, }, "multiple": { - filepath.Join("registry.ollama.ai", "library", "llama3", "latest"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"), - filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"), + ps: []string{ + filepath.Join("registry.ollama.ai", "library", "llama3", "latest"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"), + filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"), + }, + wantValidCount: 15, }, "hidden": { - filepath.Join("host", "namespace", "model", "tag"), - filepath.Join("host", "namespace", "model", ".hidden"), + ps: []string{ + filepath.Join("host", "namespace", "model", "tag"), + filepath.Join("host", "namespace", "model", ".hidden"), + }, + wantValidCount: 1, + wantInvalidCount: 1, }, "subdir": { - filepath.Join("host", "namespace", "model", "tag", "one"), - filepath.Join("host", "namespace", "model", "tag", "another", "one"), + ps: []string{ + filepath.Join("host", "namespace", "model", "tag", "one"), + filepath.Join("host", "namespace", "model", "tag", "another", "one"), + }, + wantInvalidCount: 2, + }, + "upper tag": { + ps: []string{ + filepath.Join("host", "namespace", "model", "TAG"), + }, + wantValidCount: 1, + }, + "upper model": { + ps: []string{ + filepath.Join("host", "namespace", "MODEL", "tag"), + }, + wantValidCount: 1, + }, + "upper namespace": { + ps: []string{ + filepath.Join("host", "NAMESPACE", "model", "tag"), + }, + wantValidCount: 1, + }, + "upper host": { + ps: []string{ + filepath.Join("HOST", "namespace", "model", "tag"), + }, + wantValidCount: 1, }, } @@ -67,8 +108,8 @@ func TestManifests(t *testing.T) { d := t.TempDir() t.Setenv("OLLAMA_MODELS", d) - for _, want := range wants { - createManifest(t, d, want) + for _, p := range wants.ps { + createManifest(t, d, p) } ms, err := Manifests() @@ -81,13 +122,28 @@ func TestManifests(t *testing.T) { ns = append(ns, k) } - for _, want := range wants { - n := model.ParseNameFromFilepath(want) - if !n.IsValid() && slices.Contains(ns, n) { - t.Errorf("unexpected invalid name: %s", want) - } else if n.IsValid() && !slices.Contains(ns, n) { - t.Errorf("missing valid name: %s", want) + var gotValidCount, gotInvalidCount int + for _, p := range wants.ps { + n := model.ParseNameFromFilepath(p) + if n.IsValid() { + gotValidCount++ + } else { + gotInvalidCount++ } + + if !n.IsValid() && slices.Contains(ns, n) { + t.Errorf("unexpected invalid name: %s", p) + } else if n.IsValid() && !slices.Contains(ns, n) { + t.Errorf("missing valid name: %s", p) + } + } + + if gotValidCount != wants.wantValidCount { + t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount) + } + + if gotInvalidCount != wants.wantInvalidCount { + t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount) } }) } diff --git a/server/routes_create_test.go b/server/routes_create_test.go new file mode 100644 index 00000000..e5af1ded --- /dev/null +++ b/server/routes_create_test.go @@ -0,0 +1,160 @@ +package server + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "slices" + "testing" + + "github.com/gin-gonic/gin" + "github.com/ollama/ollama/api" +) + +var stream bool = false + +func createBinFile(t *testing.T) string { + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil { + t.Fatal(err) + } + + if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil { + t.Fatal(err) + } + + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatal(err) + } + + if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil { + t.Fatal(err) + } + + return f.Name() +} + +type responseRecorder struct { + *httptest.ResponseRecorder + http.CloseNotifier +} + +func NewRecorder() *responseRecorder { + return &responseRecorder{ + ResponseRecorder: httptest.NewRecorder(), + } +} + +func (t *responseRecorder) CloseNotify() <-chan bool { + return make(chan bool) +} + +func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder { + t.Helper() + + w := NewRecorder() + c, _ := gin.CreateTestContext(w) + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(body); err != nil { + t.Fatal(err) + } + + c.Request = &http.Request{ + Body: io.NopCloser(&b), + } + + fn(c) + return w.ResponseRecorder +} + +func checkFileExists(t *testing.T, p string, expect []string) { + t.Helper() + + actual, err := filepath.Glob(p) + if err != nil { + t.Fatal(err) + } + + if !slices.Equal(actual, expect) { + t.Fatalf("expected slices to be equal %v", actual) + } +} + +func TestCreateFromBin(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + var s Server + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + }) +} + +func TestCreateFromModel(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + }) + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: "FROM test", + Stream: &stream, + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + }) +} diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go new file mode 100644 index 00000000..ea098d05 --- /dev/null +++ b/server/routes_delete_test.go @@ -0,0 +1,71 @@ +package server + +import ( + "fmt" + "net/http" + "path/filepath" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestDelete(t *testing.T) { + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + var s Server + + w := createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test", + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + w = createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: "test2", + Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)), + }) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"), + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) + + w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"}) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{ + filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"), + }) + + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ + filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"), + filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"), + filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"), + }) + + w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"}) + + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{}) + checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{}) +} diff --git a/server/routes_list_test.go b/server/routes_list_test.go new file mode 100644 index 00000000..e92b4eab --- /dev/null +++ b/server/routes_list_test.go @@ -0,0 +1,61 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "slices" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestList(t *testing.T) { + t.Setenv("OLLAMA_MODELS", t.TempDir()) + + expectNames := []string{ + "mistral:7b-instruct-q4_0", + "zephyr:7b-beta-q5_K_M", + "apple/OpenELM:latest", + "boreas:2b-code-v1.5-q6_K", + "notus:7b-v1-IQ2_S", + // TODO: host:port currently fails on windows (#4107) + // "localhost:5000/library/eurus:700b-v0.5-iq3_XXS", + "mynamespace/apeliotes:latest", + "myhost/mynamespace/lips:code", + } + + var s Server + for _, n := range expectNames { + createRequest(t, s.CreateModelHandler, api.CreateRequest{ + Name: n, + Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)), + }) + } + + w := createRequest(t, s.ListModelsHandler, nil) + if w.Code != http.StatusOK { + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + var resp api.ListResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatal(err) + } + + if len(resp.Models) != len(expectNames) { + t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models)) + } + + actualNames := make([]string, len(resp.Models)) + for i, m := range resp.Models { + actualNames[i] = m.Name + } + + slices.Sort(actualNames) + slices.Sort(expectNames) + + if !slices.Equal(actualNames, expectNames) { + t.Fatalf("expected slices to be equal %v", actualNames) + } +}