From b0d14ed51c459d3c63649ebcaa8c581431a946ae Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 12:30:34 -0800 Subject: [PATCH] refactor create model --- server/images.go | 330 +++++++++++++++++++++-------------------------- server/routes.go | 16 ++- 2 files changed, 162 insertions(+), 184 deletions(-) diff --git a/server/images.go b/server/images.go index 8d784fef..42def270 100644 --- a/server/images.go +++ b/server/images.go @@ -248,200 +248,172 @@ func filenameWithPath(path, f string) (string, error) { return f, nil } -func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { - mp := ParseModelPath(name) - - var manifest *ManifestV2 - var err error - var noprune string - - // build deleteMap to prune unused layers - deleteMap := make(map[string]bool) - - if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { - manifest, _, err = GetManifest(mp) - if err != nil && !errors.Is(err, os.ErrNotExist) { - return err - } - - if manifest != nil { - for _, l := range manifest.Layers { - deleteMap[l.Digest] = true - } - deleteMap[manifest.Config.Digest] = true - } - } - - mf, err := os.Open(path) +func realpath(p string) string { + abspath, err := filepath.Abs(p) if err != nil { - fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) - return fmt.Errorf("failed to open file: %w", err) + return p } - defer mf.Close() - fn(api.ProgressResponse{Status: "parsing modelfile"}) - commands, err := parser.Parse(mf) + home, err := os.UserHomeDir() if err != nil { - return err + return abspath } + if p == "~" { + return home + } else if strings.HasPrefix(p, "~/") { + return filepath.Join(home, p[2:]) + } + + return abspath +} + +func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error { config := ConfigV2{ - Architecture: "amd64", OS: "linux", + Architecture: "amd64", } + deleteMap := make(map[string]struct{}) + var layers []*LayerReader + params := make(map[string][]string) - var sourceParams map[string]any + fromParams := make(map[string]any) + for _, c := range commands { - log.Printf("[%s] - %s\n", c.Name, c.Args) + log.Printf("[%s] - %s", c.Name, c.Args) + mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) + switch c.Name { case "model": - fn(api.ProgressResponse{Status: "looking for model"}) - - mp := ParseModelPath(c.Args) - mf, _, err := GetManifest(mp) + bin, err := os.Open(realpath(c.Args)) if err != nil { - modelFile, err := filenameWithPath(path, c.Args) - if err != nil { + // not a file on disk so must be a model reference + modelpath := ParseModelPath(c.Args) + manifest, _, err := GetManifest(modelpath) + switch { + case errors.Is(err, os.ErrNotExist): + fn(api.ProgressResponse{Status: "pulling model"}) + if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { + return err + } + + manifest, _, err = GetManifest(modelpath) + if err != nil { + return err + } + case err != nil: return err } - if _, err := os.Stat(modelFile); err != nil { - // the model file does not exist, try pulling it - if errors.Is(err, os.ErrNotExist) { - fn(api.ProgressResponse{Status: "pulling model file"}) - if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { - return err - } - mf, _, err = GetManifest(mp) - if err != nil { - return fmt.Errorf("failed to open file after pull: %v", err) - } - } else { - return err - } - } else { - // create a model from this specified file - fn(api.ProgressResponse{Status: "creating model layer"}) - file, err := os.Open(modelFile) - if err != nil { - return fmt.Errorf("failed to open file: %v", err) - } - defer file.Close() - ggml, err := llm.DecodeGGML(file) - if err != nil { - return err - } - - config.ModelFormat = ggml.Name() - config.ModelFamily = ggml.ModelFamily() - config.ModelType = ggml.ModelType() - config.FileType = ggml.FileType() - - // reset the file - file.Seek(0, io.SeekStart) - - l, err := CreateLayer(file) - if err != nil { - return fmt.Errorf("failed to create layer: %v", err) - } - l.MediaType = "application/vnd.ollama.image.model" - layers = append(layers, l) - } - } - - if mf != nil { fn(api.ProgressResponse{Status: "reading model metadata"}) - sourceBlobPath, err := GetBlobsPath(mf.Config.Digest) + fromConfigPath, err := GetBlobsPath(manifest.Config.Digest) if err != nil { return err } - sourceBlob, err := os.Open(sourceBlobPath) + fromConfigFile, err := os.Open(fromConfigPath) if err != nil { return err } - defer sourceBlob.Close() + defer fromConfigFile.Close() - var source ConfigV2 - if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil { + var fromConfig ConfigV2 + if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil { return err } - // copy the model metadata - config.ModelFamily = source.ModelFamily - config.ModelType = source.ModelType - config.ModelFormat = source.ModelFormat - config.FileType = source.FileType + config.ModelFormat = fromConfig.ModelFormat + config.ModelFamily = fromConfig.ModelFamily + config.ModelType = fromConfig.ModelType + config.FileType = fromConfig.FileType - for _, l := range mf.Layers { - if l.MediaType == "application/vnd.ollama.image.params" { - sourceParamsBlobPath, err := GetBlobsPath(l.Digest) + for _, layer := range manifest.Layers { + deleteMap[layer.Digest] = struct{}{} + if layer.MediaType == "application/vnd.ollama.image.params" { + fromParamsPath, err := GetBlobsPath(layer.Digest) if err != nil { return err } - sourceParamsBlob, err := os.Open(sourceParamsBlobPath) + fromParamsFile, err := os.Open(fromParamsPath) if err != nil { return err } - defer sourceParamsBlob.Close() + defer fromParamsFile.Close() - if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil { + if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil { return err } } - newLayer, err := GetLayerWithBufferFromLayer(l) + layer, err := GetLayerWithBufferFromLayer(layer) if err != nil { return err } - newLayer.From = mp.GetShortTagname() - layers = append(layers, newLayer) - } - } - case "adapter": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - fp, err := filenameWithPath(path, c.Args) + layer.From = modelpath.GetShortTagname() + layers = append(layers, layer) + } + + deleteMap[manifest.Config.Digest] = struct{}{} + continue + } + defer bin.Close() + + fn(api.ProgressResponse{Status: "creating model layer"}) + ggml, err := llm.DecodeGGML(bin) if err != nil { return err } - // create a model from this specified file - fn(api.ProgressResponse{Status: "creating model layer"}) + config.ModelFormat = ggml.Name() + config.ModelFamily = ggml.ModelFamily() + config.ModelType = ggml.ModelType() + config.FileType = ggml.FileType() - file, err := os.Open(fp) + bin.Seek(0, io.SeekStart) + layer, err := CreateLayer(bin) if err != nil { - return fmt.Errorf("failed to open file: %v", err) + return err } - defer file.Close() - l, err := CreateLayer(file) + layer.MediaType = mediatype + layers = append(layers, layer) + case "adapter": + fn(api.ProgressResponse{Status: "creating adapter layer"}) + bin, err := os.Open(realpath(c.Args)) if err != nil { - return fmt.Errorf("failed to create layer: %v", err) + return err } - l.MediaType = "application/vnd.ollama.image.adapter" - layers = append(layers, l) - case "license": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) + defer bin.Close() - layer, err := CreateLayer(strings.NewReader(c.Args)) + layer, err := CreateLayer(bin) if err != nil { return err } if layer.Size > 0 { - layer.MediaType = mediaType + layer.MediaType = mediatype layers = append(layers, layer) } - case "template", "system", "prompt": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - // remove the layer if one exists - mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) - layers = removeLayerFromLayers(layers, mediaType) + case "license": + fn(api.ProgressResponse{Status: "creating license layer"}) + layer, err := CreateLayer(strings.NewReader(c.Args)) + if err != nil { + return err + } + + if layer.Size > 0 { + layer.MediaType = mediatype + layers = append(layers, layer) + } + case "template", "system": + fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)}) + + // remove duplicates layers + layers = removeLayerFromLayers(layers, mediatype) layer, err := CreateLayer(strings.NewReader(c.Args)) if err != nil { @@ -449,48 +421,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } if layer.Size > 0 { - layer.MediaType = mediaType + layer.MediaType = mediatype layers = append(layers, layer) } default: - // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences) params[c.Name] = append(params[c.Name], c.Args) } } - // Create a single layer for the parameters if len(params) > 0 { - fn(api.ProgressResponse{Status: "creating parameter layer"}) + fn(api.ProgressResponse{Status: "creating parameters layer"}) - layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") formattedParams, err := formatParams(params) if err != nil { - return fmt.Errorf("couldn't create params json: %v", err) + return err } - for k, v := range sourceParams { + for k, v := range fromParams { if _, ok := formattedParams[k]; !ok { formattedParams[k] = v } } if config.ModelType == "65B" { - if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { + if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 { config.ModelType = "70B" } } - bts, err := json.Marshal(formattedParams) + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(formattedParams); err != nil { + return err + } + + fn(api.ProgressResponse{Status: "creating config layer"}) + layer, err := CreateLayer(bytes.NewReader(b.Bytes())) if err != nil { return err } - l, err := CreateLayer(bytes.NewReader(bts)) - if err != nil { - return fmt.Errorf("failed to create layer: %v", err) - } - l.MediaType = "application/vnd.ollama.image.params" - layers = append(layers, l) + layer.MediaType = "application/vnd.ollama.image.params" + layers = append(layers, layer) } digests, err := getLayerDigests(layers) @@ -498,36 +469,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api return err } - var manifestLayers []*Layer - for _, l := range layers { - manifestLayers = append(manifestLayers, &l.Layer) - delete(deleteMap, l.Layer.Digest) - } - - // Create a layer for the config object - fn(api.ProgressResponse{Status: "creating config layer"}) - cfg, err := createConfigLayer(config, digests) + configLayer, err := createConfigLayer(config, digests) if err != nil { return err } - layers = append(layers, cfg) - delete(deleteMap, cfg.Layer.Digest) + + layers = append(layers, configLayer) + delete(deleteMap, configLayer.Digest) if err := SaveLayers(layers, fn, false); err != nil { return err } - // Create the manifest + var contentLayers []*Layer + for _, layer := range layers { + contentLayers = append(contentLayers, &layer.Layer) + delete(deleteMap, layer.Digest) + } + fn(api.ProgressResponse{Status: "writing manifest"}) - err = CreateManifest(name, cfg, manifestLayers) - if err != nil { + if err := CreateManifest(name, configLayer, contentLayers); err != nil { return err } - if noprune == "" { - fn(api.ProgressResponse{Status: "removing any unused layers"}) - err = deleteUnusedLayers(nil, deleteMap, false) - if err != nil { + if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if err := deleteUnusedLayers(nil, deleteMap, false); err != nil { return err } } @@ -739,7 +705,7 @@ func CopyModel(src, dest string) error { return nil } -func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error { +func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error { fp, err := GetManifestPath() if err != nil { return err @@ -779,21 +745,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry } // only delete the files which are still in the deleteMap - for k, v := range deleteMap { - if v { - fp, err := GetBlobsPath(k) - if err != nil { - log.Printf("couldn't get file path for '%s': %v", k, err) + for k := range deleteMap { + fp, err := GetBlobsPath(k) + if err != nil { + log.Printf("couldn't get file path for '%s': %v", k, err) + continue + } + if !dryRun { + if err := os.Remove(fp); err != nil { + log.Printf("couldn't remove file '%s': %v", fp, err) continue } - if !dryRun { - if err := os.Remove(fp); err != nil { - log.Printf("couldn't remove file '%s': %v", fp, err) - continue - } - } else { - log.Printf("wanted to remove: %s", fp) - } + } else { + log.Printf("wanted to remove: %s", fp) } } @@ -801,7 +765,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry } func PruneLayers() error { - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) p, err := GetBlobsPath("") if err != nil { return err @@ -818,7 +782,7 @@ func PruneLayers() error { if runtime.GOOS == "windows" { name = strings.ReplaceAll(name, "-", ":") } - deleteMap[name] = true + deleteMap[name] = struct{}{} } log.Printf("total blobs: %d", len(deleteMap)) @@ -873,11 +837,11 @@ func DeleteModel(name string) error { return err } - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) for _, layer := range manifest.Layers { - deleteMap[layer.Digest] = true + deleteMap[layer.Digest] = struct{}{} } - deleteMap[manifest.Config.Digest] = true + deleteMap[manifest.Config.Digest] = struct{}{} err = deleteUnusedLayers(&mp, deleteMap, false) if err != nil { @@ -1013,7 +977,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu var noprune string // build deleteMap to prune unused layers - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { manifest, _, err = GetManifest(mp) @@ -1023,9 +987,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu if manifest != nil { for _, l := range manifest.Layers { - deleteMap[l.Digest] = true + deleteMap[l.Digest] = struct{}{} } - deleteMap[manifest.Config.Digest] = true + deleteMap[manifest.Config.Digest] = struct{}{} } } diff --git a/server/routes.go b/server/routes.go index a543b10e..e53bad43 100644 --- a/server/routes.go +++ b/server/routes.go @@ -26,6 +26,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llm" + "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/version" ) @@ -414,6 +415,19 @@ func CreateModelHandler(c *gin.Context) { return } + modelfile, err := os.Open(req.Path) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + defer modelfile.Close() + + commands, err := parser.Parse(modelfile) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + ch := make(chan any) go func() { defer close(ch) @@ -424,7 +438,7 @@ func CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { + if err := CreateModel(ctx, req.Name, commands, fn); err != nil { ch <- gin.H{"error": err.Error()} } }()