update copy to use model.Name

This commit is contained in:
Michael Yang 2024-04-16 16:22:38 -07:00
parent 2010cbc5fa
commit 592dae31c8
2 changed files with 32 additions and 34 deletions

View file

@ -29,6 +29,7 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@ -701,36 +702,32 @@ func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string
return path, nil return path, nil
} }
func CopyModel(src, dest string) error { func CopyModel(src, dst model.Name) error {
srcModelPath := ParseModelPath(src) manifests, err := GetManifestPath()
srcPath, err := srcModelPath.GetManifestPath()
if err != nil { if err != nil {
return err return err
} }
destModelPath := ParseModelPath(dest) dstpath := filepath.Join(manifests, dst.FilepathNoBuild())
destPath, err := destModelPath.GetManifestPath() if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
return err return err
} }
// copy the file srcpath := filepath.Join(manifests, src.FilepathNoBuild())
input, err := os.ReadFile(srcPath) srcfile, err := os.Open(srcpath)
if err != nil { if err != nil {
fmt.Println("Error reading file:", err)
return err return err
} }
defer srcfile.Close()
err = os.WriteFile(destPath, input, 0o644) dstfile, err := os.Create(dstpath)
if err != nil { if err != nil {
fmt.Println("Error reading file:", err)
return err return err
} }
defer dstfile.Close()
return nil _, err = io.Copy(dstfile, srcfile)
return err
} }
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error { func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {

View file

@ -29,6 +29,7 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@ -788,34 +789,34 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
} }
func (s *Server) CopyModelHandler(c *gin.Context) { func (s *Server) CopyModelHandler(c *gin.Context) {
var req api.CopyRequest var r api.CopyRequest
err := c.ShouldBindJSON(&req) if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
case err != nil: } else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if req.Source == "" || req.Destination == "" { src := model.ParseName(r.Source)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"}) if !src.IsValid() {
_ = c.Error(fmt.Errorf("source %q is invalid", r.Source))
}
dst := model.ParseName(r.Destination)
if !dst.IsValid() {
_ = c.Error(fmt.Errorf("destination %q is invalid", r.Destination))
}
if len(c.Errors) > 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": c.Errors.Errors()})
return return
} }
if err := ParseModelPath(req.Destination).Validate(); err != nil { if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
return } else if err != nil {
} c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
} }
} }