From 3520c0e4d5c1cc845d178ec080b0967d18cf1796 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 10 May 2024 15:48:41 -0700 Subject: [PATCH] cache and reuse intermediate blobs particularly useful for zipfiles and f16s --- server/images.go | 27 ++++++++++++++++++++++++--- server/layer.go | 2 +- server/model.go | 23 +++++++++-------------- server/routes.go | 19 +++++++++++++++++++ 4 files changed, 53 insertions(+), 18 deletions(-) diff --git a/server/images.go b/server/images.go index 0ccc90b9..8e8fd921 100644 --- a/server/images.go +++ b/server/images.go @@ -340,7 +340,24 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return err } } else if strings.HasPrefix(c.Args, "@") { - blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) + digest := strings.TrimPrefix(c.Args, "@") + if ib, ok := intermediateBlobs.Load(digest); ok { + p, err := GetBlobsPath(ib.(string)) + if err != nil { + return err + } + + if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) { + // pass + } else if err != nil { + return err + } else { + fn(api.ProgressResponse{Status: fmt.Sprintf("using cached layer %s", ib.(string))}) + digest = ib.(string) + } + } + + blobpath, err := GetBlobsPath(digest) if err != nil { return err } @@ -351,14 +368,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m } defer blob.Close() - baseLayers, err = parseFromFile(ctx, blob, fn) + baseLayers, err = parseFromFile(ctx, blob, digest, fn) if err != nil { return err } } else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil { defer file.Close() - baseLayers, err = parseFromFile(ctx, file, fn) + baseLayers, err = parseFromFile(ctx, file, "", fn) if err != nil { return err } @@ -398,10 +415,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return err } + f16digest := baseLayer.Layer.Digest + baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType) if err != nil { return err } + + intermediateBlobs.Store(f16digest, baseLayer.Layer.Digest) } } diff --git a/server/layer.go b/server/layer.go index dcca3854..9ca43046 100644 --- a/server/layer.go +++ b/server/layer.go @@ -80,7 +80,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) { }, nil } -func (l *Layer) Open() (io.ReadCloser, error) { +func (l *Layer) Open() (io.ReadSeekCloser, error) { blob, err := GetBlobsPath(l.Digest) if err != nil { return nil, err diff --git a/server/model.go b/server/model.go index eea5d13a..eabb8f3d 100644 --- a/server/model.go +++ b/server/model.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "path/filepath" + "sync" "github.com/ollama/ollama/api" "github.com/ollama/ollama/convert" @@ -17,6 +18,8 @@ import ( "github.com/ollama/ollama/types/model" ) +var intermediateBlobs sync.Map + type layerWithGGML struct { *Layer *llm.GGML @@ -76,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return layers, nil } -func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { +func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { stat, err := file.Stat() if err != nil { return nil, err @@ -169,12 +172,7 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp return nil, fmt.Errorf("aaa: %w", err) } - blobpath, err := GetBlobsPath(layer.Digest) - if err != nil { - return nil, err - } - - bin, err := os.Open(blobpath) + bin, err := layer.Open() if err != nil { return nil, err } @@ -185,16 +183,13 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp return nil, err } - layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "") - if err != nil { - return nil, err - } - layers = append(layers, &layerWithGGML{layer, ggml}) + + intermediateBlobs.Store(digest, layer.Digest) return layers, nil } -func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { +func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { sr := io.NewSectionReader(file, 0, 512) contentType, err := detectContentType(sr) if err != nil { @@ -205,7 +200,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo case "gguf", "ggla": // noop case "application/zip": - return parseFromZipFile(ctx, file, fn) + return parseFromZipFile(ctx, file, digest, fn) default: return nil, fmt.Errorf("unsupported content type: %s", contentType) } diff --git a/server/routes.go b/server/routes.go index fff228f3..12b11b5c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -841,6 +841,25 @@ func (s *Server) HeadBlobHandler(c *gin.Context) { } func (s *Server) CreateBlobHandler(c *gin.Context) { + ib, ok := intermediateBlobs.Load(c.Param("digest")) + if ok { + p, err := GetBlobsPath(ib.(string)) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) { + intermediateBlobs.Delete(c.Param("digest")) + } else if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } else { + c.Status(http.StatusOK) + return + } + } + path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})