From 4c1caa373376db583fe2c6aa88c873b476b5f92b Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 25 Jul 2023 14:25:13 -0400 Subject: [PATCH] download models when creating from modelfile --- api/client.go | 4 +-- api/types.go | 4 --- cmd/cmd.go | 31 +++++++++++++++++----- server/images.go | 69 +++++++++++++++++++++++++++++++----------------- server/routes.go | 6 ++--- 5 files changed, 73 insertions(+), 41 deletions(-) diff --git a/api/client.go b/api/client.go index 7a6126cd..6786fa48 100644 --- a/api/client.go +++ b/api/client.go @@ -189,11 +189,11 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc }) } -type CreateProgressFunc func(CreateProgress) error +type CreateProgressFunc func(ProgressResponse) error func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { - var resp CreateProgress + var resp ProgressResponse if err := json.Unmarshal(bts, &resp); err != nil { return err } diff --git a/api/types.go b/api/types.go index cabec90a..07ce8122 100644 --- a/api/types.go +++ b/api/types.go @@ -40,10 +40,6 @@ type CreateRequest struct { Path string `json:"path"` } -type CreateProgress struct { - Status string `json:"status"` -} - type DeleteRequest struct { Name string `json:"name"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 693d0cab..7761b03b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -36,15 +36,32 @@ func CreateHandler(cmd *cobra.Command, args []string) error { var spinner *Spinner + var currentDigest string + var bar *progressbar.ProgressBar + request := api.CreateRequest{Name: args[0], Path: filename} - fn := func(resp api.CreateProgress) error { - if spinner != nil { - spinner.Stop() + fn := func(resp api.ProgressResponse) error { + if resp.Digest != currentDigest && resp.Digest != "" { + if spinner != nil { + spinner.Stop() + } + currentDigest = resp.Digest + bar = progressbar.DefaultBytes( + int64(resp.Total), + fmt.Sprintf("pulling %s...", resp.Digest[7:19]), + ) + + bar.Set(resp.Completed) + } else if resp.Digest == currentDigest && resp.Digest != "" { + bar.Set(resp.Completed) + } else { + currentDigest = "" + if spinner != nil { + spinner.Stop() + } + spinner = NewSpinner(resp.Status) + go spinner.Spin(100 * time.Millisecond) } - - spinner = NewSpinner(resp.Status) - go spinner.Spin(100 * time.Millisecond) - return nil } diff --git a/server/images.go b/server/images.go index 370a5e6d..6e74b055 100644 --- a/server/images.go +++ b/server/images.go @@ -187,15 +187,15 @@ func GetModel(name string) (*Model, error) { return model, nil } -func CreateModel(name string, path string, fn func(status string)) error { +func CreateModel(name string, path string, fn func(resp api.ProgressResponse)) error { mf, err := os.Open(path) if err != nil { - fn(fmt.Sprintf("couldn't open modelfile '%s'", path)) + fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) return fmt.Errorf("failed to open file: %w", err) } defer mf.Close() - fn("parsing modelfile") + fn(api.ProgressResponse{Status: "parsing modelfile"}) commands, err := parser.Parse(mf) if err != nil { return err @@ -208,7 +208,7 @@ func CreateModel(name string, path string, fn func(status string)) error { log.Printf("[%s] - %s\n", c.Name, c.Args) switch c.Name { case "model": - fn("looking for model") + fn(api.ProgressResponse{Status: "looking for model"}) mf, err := GetManifest(ParseModelPath(c.Args)) if err != nil { fp := c.Args @@ -229,20 +229,40 @@ func CreateModel(name string, path string, fn func(status string)) error { fp = filepath.Join(filepath.Dir(path), fp) } - fn("creating model layer") - file, err := os.Open(fp) - if err != nil { - return fmt.Errorf("failed to open file: %v", err) - } - defer file.Close() + if _, err := os.Stat(fp); err != nil { + // the model file does not exist, try pulling it + if errors.Is(err, os.ErrNotExist) { + fn(api.ProgressResponse{Status: "pulling model file"}) + if err := PullModel(c.Args, &RegistryOptions{}, fn); err != nil { + return err + } + mf, err = GetManifest(ParseModelPath(c.Args)) + if err != nil { + return fmt.Errorf("failed to open file after pull: %v", err) + } - l, err := CreateLayer(file) - if err != nil { - return fmt.Errorf("failed to create layer: %v", err) + } else { + return err + } + } else { + // create a model from this specified file + fn(api.ProgressResponse{Status: "creating model layer"}) + + file, err := os.Open(fp) + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + defer file.Close() + + l, err := CreateLayer(file) + if err != nil { + return fmt.Errorf("failed to create layer: %v", err) + } + l.MediaType = "application/vnd.ollama.image.model" + layers = append(layers, l) } - l.MediaType = "application/vnd.ollama.image.model" - layers = append(layers, l) - } else { + } + if mf != nil { log.Printf("manifest = %#v", mf) for _, l := range mf.Layers { newLayer, err := GetLayerWithBufferFromLayer(l) @@ -253,7 +273,7 @@ func CreateModel(name string, path string, fn func(status string)) error { } } case "license", "template", "system", "prompt": - fn(fmt.Sprintf("creating %s layer", c.Name)) + fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) // remove the prompt layer if one exists mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) layers = removeLayerFromLayers(layers, mediaType) @@ -272,7 +292,7 @@ func CreateModel(name string, path string, fn func(status string)) error { // Create a single layer for the parameters if len(params) > 0 { - fn("creating parameter layer") + fn(api.ProgressResponse{Status: "creating parameter layer"}) layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") paramData, err := paramsToReader(params) if err != nil { @@ -297,7 +317,7 @@ func CreateModel(name string, path string, fn func(status string)) error { } // Create a layer for the config object - fn("creating config layer") + fn(api.ProgressResponse{Status: "creating config layer"}) cfg, err := createConfigLayer(digests) if err != nil { return err @@ -310,13 +330,13 @@ func CreateModel(name string, path string, fn func(status string)) error { } // Create the manifest - fn("writing manifest") + fn(api.ProgressResponse{Status: "writing manifest"}) err = CreateManifest(name, cfg, manifestLayers) if err != nil { return err } - fn("success") + fn(api.ProgressResponse{Status: "success"}) return nil } @@ -331,7 +351,7 @@ func removeLayerFromLayers(layers []*LayerReader, mediaType string) []*LayerRead return layers[:j] } -func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error { +func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force bool) error { // Write each of the layers to disk for _, layer := range layers { fp, err := GetBlobsPath(layer.Digest) @@ -341,7 +361,8 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error _, err = os.Stat(fp) if os.IsNotExist(err) || force { - fn(fmt.Sprintf("writing layer %s", layer.Digest)) + fn(api.ProgressResponse{Status: fmt.Sprintf("writing layer %s", layer.Digest)}) + out, err := os.Create(fp) if err != nil { log.Printf("couldn't create %s", fp) @@ -354,7 +375,7 @@ func SaveLayers(layers []*LayerReader, fn func(status string), force bool) error } } else { - fn(fmt.Sprintf("using already created layer %s", layer.Digest)) + fn(api.ProgressResponse{Status: fmt.Sprintf("using already created layer %s", layer.Digest)}) } } diff --git a/server/routes.go b/server/routes.go index d5b2e127..aabcb718 100644 --- a/server/routes.go +++ b/server/routes.go @@ -147,10 +147,8 @@ func CreateModelHandler(c *gin.Context) { ch := make(chan any) go func() { defer close(ch) - fn := func(status string) { - ch <- api.CreateProgress{ - Status: status, - } + fn := func(resp api.ProgressResponse) { + ch <- resp } if err := CreateModel(req.Name, req.Path, fn); err != nil {