diff --git a/server/layer.go b/server/layer.go index dcca3854..d3d3c120 100644 --- a/server/layer.go +++ b/server/layer.go @@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadCloser, error) { return os.Open(blob) } + +func (l *Layer) Remove() error { + ms, err := Manifests() + if err != nil { + return err + } + + for _, m := range ms { + for _, layer := range append(m.Layers, m.Config) { + if layer.Digest == l.Digest { + // something is using this layer + return nil + } + } + } + + blob, err := GetBlobsPath(l.Digest) + if err != nil { + return err + } + + return os.Remove(blob) +} diff --git a/server/manifest.go b/server/manifest.go index 8a17700e..36ed5b4c 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -14,7 +14,9 @@ import ( type Manifest struct { ManifestV2 - Digest string `json:"-"` + + filepath string + digest string } func (m *Manifest) Size() (size int64) { @@ -25,9 +27,28 @@ func (m *Manifest) Size() (size int64) { return } -func ParseNamedManifest(name model.Name) (*Manifest, error) { - if !name.IsFullyQualified() { - return nil, model.Unqualified(name) +func (m *Manifest) Remove() error { + if err := os.Remove(m.filepath); err != nil { + return err + } + + for _, layer := range append(m.Layers, m.Config) { + if err := layer.Remove(); err != nil { + return err + } + } + + manifests, err := GetManifestPath() + if err != nil { + return err + } + + return PruneDirectory(manifests) +} + +func ParseNamedManifest(n model.Name) (*Manifest, error) { + if !n.IsFullyQualified() { + return nil, model.Unqualified(n) } manifests, err := GetManifestPath() @@ -35,20 +56,24 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) { return nil, err } - var manifest ManifestV2 - manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath())) + p := filepath.Join(manifests, n.Filepath()) + + var m ManifestV2 + f, err := os.Open(p) if err != nil { return nil, err } + defer f.Close() sha256sum := sha256.New() - if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil { + if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil { return nil, err } return &Manifest{ - ManifestV2: manifest, - Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), + ManifestV2: m, + filepath: p, + digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), }, nil } @@ -77,3 +102,36 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error { return os.WriteFile(manifestPath, b.Bytes(), 0o644) } + +func Manifests() (map[model.Name]*Manifest, error) { + manifests, err := GetManifestPath() + if err != nil { + return nil, err + } + + // TODO(mxyng): use something less brittle + matches, err := filepath.Glob(fmt.Sprintf("%s/*/*/*/*", manifests)) + if err != nil { + return nil, err + } + + ms := make(map[model.Name]*Manifest) + for _, match := range matches { + rel, err := filepath.Rel(manifests, match) + if err != nil { + return nil, err + } + + n := model.ParseNameFromFilepath(rel) + if n.IsValid() { + m, err := ParseNamedManifest(n) + if err != nil { + return nil, err + } + + ms[n] = m + } + } + + return ms, nil +} diff --git a/server/routes.go b/server/routes.go index 123ef9a3..ff888e3c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -574,48 +574,31 @@ func (s *Server) CreateModelHandler(c *gin.Context) { } func (s *Server) DeleteModelHandler(c *gin.Context) { - var req api.DeleteRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + var r api.DeleteRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + n := model.ParseName(cmp.Or(r.Model, r.Name)) + if !n.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) return } - if err := DeleteModel(model); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)}) - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return - } - - manifestsPath, err := GetManifestPath() + m, err := ParseNamedManifest(n) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - if err := PruneDirectory(manifestsPath); err != nil { + if err := m.Remove(); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - - c.JSON(http.StatusOK, nil) } func (s *Server) ShowModelHandler(c *gin.Context) { @@ -769,7 +752,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) { Model: n.DisplayShortest(), Name: n.DisplayShortest(), Size: m.Size(), - Digest: m.Digest, + Digest: m.digest, ModifiedAt: info.ModTime(), Details: api.ModelDetails{ Format: c.ModelFormat,