Merge pull request #6363 from ollama/mxyng/fix-noprune

fix: noprune on pull
This commit is contained in:
Michael Yang 2024-08-15 12:20:38 -07:00 committed by GitHub
commit e3d7f32af7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 31 additions and 68 deletions

View file

@ -215,25 +215,20 @@ func GetManifest(mp ModelPath) (*Manifest, string, error) {
return nil, "", err
}
if _, err = os.Stat(fp); err != nil {
return nil, "", err
}
var manifest *Manifest
bts, err := os.ReadFile(fp)
f, err := os.Open(fp)
if err != nil {
return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
return nil, "", err
}
defer f.Close()
shaSum := sha256.Sum256(bts)
shaStr := hex.EncodeToString(shaSum[:])
sha256sum := sha256.New()
if err := json.Unmarshal(bts, &manifest); err != nil {
var manifest Manifest
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
return nil, "", err
}
return manifest, shaStr, nil
return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}
func GetModel(name string) (*Model, error) {
@ -692,43 +687,18 @@ func CopyModel(src, dst model.Name) error {
return err
}
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}) error {
fp, err := GetManifestPath()
if err != nil {
return err
}
walkFunc := func(path string, info os.FileInfo, _ error) error {
if info.IsDir() {
return nil
}
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
fmp := ParseModelPath(tag)
// skip the manifest we're trying to delete
if skipModelPath != nil && skipModelPath.GetFullTagname() == fmp.GetFullTagname() {
return nil
}
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp)
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
manifests, err := Manifests()
if err != nil {
return err
}
for _, manifest := range manifests {
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
return nil
}
if err := filepath.Walk(fp, walkFunc); err != nil {
return err
}
// only delete the files which are still in the deleteMap
@ -781,8 +751,7 @@ func PruneLayers() error {
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
err = deleteUnusedLayers(nil, deleteMap)
if err != nil {
if err := deleteUnusedLayers(deleteMap); err != nil {
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
return nil
}
@ -877,20 +846,14 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *Manifest
var err error
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
if !envconfig.NoPrune() {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
manifest, _, err := GetManifest(mp)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
} else {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
@ -898,7 +861,6 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
deleteMap[manifest.Config.Digest] = struct{}{}
}
}
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http")
@ -975,11 +937,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return err
}
if noprune == "" {
fn(api.ProgressResponse{Status: "removing any unused layers"})
err = deleteUnusedLayers(nil, deleteMap)
if err != nil {
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
if !envconfig.NoPrune() && len(deleteMap) > 0 {
fn(api.ProgressResponse{Status: "removing unused layers"})
if err := deleteUnusedLayers(deleteMap); err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
}
}
@ -1000,12 +960,12 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
}
defer resp.Body.Close()
var m *Manifest
var m Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
return m, err
return &m, err
}
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer

View file

@ -5,6 +5,7 @@ import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
@ -150,14 +151,16 @@ func Manifests() (map[model.Name]*Manifest, error) {
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest name", "path", rel, "error", err)
slog.Warn("bad manifest name", "path", rel)
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
slog.Warn("bad manifest", "name", n, "error", err)
continue
} else if err != nil {
return nil, fmt.Errorf("%s: %w", n, err)
}
ms[n] = m