Merge pull request #4330 from ollama/mxyng/cache-intermediate-layers
cache and reuse intermediate blobs
This commit is contained in:
commit
b4dce13309
4 changed files with 53 additions and 18 deletions
|
@ -340,7 +340,24 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if strings.HasPrefix(c.Args, "@") {
|
} else if strings.HasPrefix(c.Args, "@") {
|
||||||
blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
|
digest := strings.TrimPrefix(c.Args, "@")
|
||||||
|
if ib, ok := intermediateBlobs.Load(digest); ok {
|
||||||
|
p, err := GetBlobsPath(ib.(string))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
|
||||||
|
// pass
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
fn(api.ProgressResponse{Status: fmt.Sprintf("using cached layer %s", ib.(string))})
|
||||||
|
digest = ib.(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
blobpath, err := GetBlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -351,14 +368,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
|
||||||
}
|
}
|
||||||
defer blob.Close()
|
defer blob.Close()
|
||||||
|
|
||||||
baseLayers, err = parseFromFile(ctx, blob, fn)
|
baseLayers, err = parseFromFile(ctx, blob, digest, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
|
} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
baseLayers, err = parseFromFile(ctx, file, fn)
|
baseLayers, err = parseFromFile(ctx, file, "", fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -398,10 +415,14 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
f16digest := baseLayer.Layer.Digest
|
||||||
|
|
||||||
baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
|
baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
intermediateBlobs.Store(f16digest, baseLayer.Layer.Digest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Layer) Open() (io.ReadCloser, error) {
|
func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||||
blob, err := GetBlobsPath(l.Digest)
|
blob, err := GetBlobsPath(l.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
|
@ -17,6 +18,8 @@ import (
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var intermediateBlobs sync.Map
|
||||||
|
|
||||||
type layerWithGGML struct {
|
type layerWithGGML struct {
|
||||||
*Layer
|
*Layer
|
||||||
*llm.GGML
|
*llm.GGML
|
||||||
|
@ -76,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||||
stat, err := file.Stat()
|
stat, err := file.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -169,12 +172,7 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
|
||||||
return nil, fmt.Errorf("aaa: %w", err)
|
return nil, fmt.Errorf("aaa: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
blobpath, err := GetBlobsPath(layer.Digest)
|
bin, err := layer.Open()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bin, err := os.Open(blobpath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -185,16 +183,13 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
layers = append(layers, &layerWithGGML{layer, ggml})
|
layers = append(layers, &layerWithGGML{layer, ggml})
|
||||||
|
|
||||||
|
intermediateBlobs.Store(digest, layer.Digest)
|
||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||||
sr := io.NewSectionReader(file, 0, 512)
|
sr := io.NewSectionReader(file, 0, 512)
|
||||||
contentType, err := detectContentType(sr)
|
contentType, err := detectContentType(sr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -205,7 +200,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo
|
||||||
case "gguf", "ggla":
|
case "gguf", "ggla":
|
||||||
// noop
|
// noop
|
||||||
case "application/zip":
|
case "application/zip":
|
||||||
return parseFromZipFile(ctx, file, fn)
|
return parseFromZipFile(ctx, file, digest, fn)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported content type: %s", contentType)
|
return nil, fmt.Errorf("unsupported content type: %s", contentType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -841,6 +841,25 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||||
|
ib, ok := intermediateBlobs.Load(c.Param("digest"))
|
||||||
|
if ok {
|
||||||
|
p, err := GetBlobsPath(ib.(string))
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
|
||||||
|
intermediateBlobs.Delete(c.Param("digest"))
|
||||||
|
} else if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
path, err := GetBlobsPath(c.Param("digest"))
|
path, err := GetBlobsPath(c.Param("digest"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
|
Loading…
Reference in a new issue