From b0d14ed51c459d3c63649ebcaa8c581431a946ae Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 12:30:34 -0800 Subject: [PATCH 01/11] refactor create model --- server/images.go | 330 +++++++++++++++++++++-------------------------- server/routes.go | 16 ++- 2 files changed, 162 insertions(+), 184 deletions(-) diff --git a/server/images.go b/server/images.go index 8d784fef..42def270 100644 --- a/server/images.go +++ b/server/images.go @@ -248,200 +248,172 @@ func filenameWithPath(path, f string) (string, error) { return f, nil } -func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { - mp := ParseModelPath(name) - - var manifest *ManifestV2 - var err error - var noprune string - - // build deleteMap to prune unused layers - deleteMap := make(map[string]bool) - - if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { - manifest, _, err = GetManifest(mp) - if err != nil && !errors.Is(err, os.ErrNotExist) { - return err - } - - if manifest != nil { - for _, l := range manifest.Layers { - deleteMap[l.Digest] = true - } - deleteMap[manifest.Config.Digest] = true - } - } - - mf, err := os.Open(path) +func realpath(p string) string { + abspath, err := filepath.Abs(p) if err != nil { - fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) - return fmt.Errorf("failed to open file: %w", err) + return p } - defer mf.Close() - fn(api.ProgressResponse{Status: "parsing modelfile"}) - commands, err := parser.Parse(mf) + home, err := os.UserHomeDir() if err != nil { - return err + return abspath } + if p == "~" { + return home + } else if strings.HasPrefix(p, "~/") { + return filepath.Join(home, p[2:]) + } + + return abspath +} + +func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error { config := ConfigV2{ - Architecture: "amd64", OS: "linux", + Architecture: "amd64", } + deleteMap := make(map[string]struct{}) + var layers []*LayerReader + params := make(map[string][]string) - var sourceParams map[string]any + fromParams := make(map[string]any) + for _, c := range commands { - log.Printf("[%s] - %s\n", c.Name, c.Args) + log.Printf("[%s] - %s", c.Name, c.Args) + mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) + switch c.Name { case "model": - fn(api.ProgressResponse{Status: "looking for model"}) - - mp := ParseModelPath(c.Args) - mf, _, err := GetManifest(mp) + bin, err := os.Open(realpath(c.Args)) if err != nil { - modelFile, err := filenameWithPath(path, c.Args) - if err != nil { + // not a file on disk so must be a model reference + modelpath := ParseModelPath(c.Args) + manifest, _, err := GetManifest(modelpath) + switch { + case errors.Is(err, os.ErrNotExist): + fn(api.ProgressResponse{Status: "pulling model"}) + if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { + return err + } + + manifest, _, err = GetManifest(modelpath) + if err != nil { + return err + } + case err != nil: return err } - if _, err := os.Stat(modelFile); err != nil { - // the model file does not exist, try pulling it - if errors.Is(err, os.ErrNotExist) { - fn(api.ProgressResponse{Status: "pulling model file"}) - if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { - return err - } - mf, _, err = GetManifest(mp) - if err != nil { - return fmt.Errorf("failed to open file after pull: %v", err) - } - } else { - return err - } - } else { - // create a model from this specified file - fn(api.ProgressResponse{Status: "creating model layer"}) - file, err := os.Open(modelFile) - if err != nil { - return fmt.Errorf("failed to open file: %v", err) - } - defer file.Close() - ggml, err := llm.DecodeGGML(file) - if err != nil { - return err - } - - config.ModelFormat = ggml.Name() - config.ModelFamily = ggml.ModelFamily() - config.ModelType = ggml.ModelType() - config.FileType = ggml.FileType() - - // reset the file - file.Seek(0, io.SeekStart) - - l, err := CreateLayer(file) - if err != nil { - return fmt.Errorf("failed to create layer: %v", err) - } - l.MediaType = "application/vnd.ollama.image.model" - layers = append(layers, l) - } - } - - if mf != nil { fn(api.ProgressResponse{Status: "reading model metadata"}) - sourceBlobPath, err := GetBlobsPath(mf.Config.Digest) + fromConfigPath, err := GetBlobsPath(manifest.Config.Digest) if err != nil { return err } - sourceBlob, err := os.Open(sourceBlobPath) + fromConfigFile, err := os.Open(fromConfigPath) if err != nil { return err } - defer sourceBlob.Close() + defer fromConfigFile.Close() - var source ConfigV2 - if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil { + var fromConfig ConfigV2 + if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil { return err } - // copy the model metadata - config.ModelFamily = source.ModelFamily - config.ModelType = source.ModelType - config.ModelFormat = source.ModelFormat - config.FileType = source.FileType + config.ModelFormat = fromConfig.ModelFormat + config.ModelFamily = fromConfig.ModelFamily + config.ModelType = fromConfig.ModelType + config.FileType = fromConfig.FileType - for _, l := range mf.Layers { - if l.MediaType == "application/vnd.ollama.image.params" { - sourceParamsBlobPath, err := GetBlobsPath(l.Digest) + for _, layer := range manifest.Layers { + deleteMap[layer.Digest] = struct{}{} + if layer.MediaType == "application/vnd.ollama.image.params" { + fromParamsPath, err := GetBlobsPath(layer.Digest) if err != nil { return err } - sourceParamsBlob, err := os.Open(sourceParamsBlobPath) + fromParamsFile, err := os.Open(fromParamsPath) if err != nil { return err } - defer sourceParamsBlob.Close() + defer fromParamsFile.Close() - if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil { + if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil { return err } } - newLayer, err := GetLayerWithBufferFromLayer(l) + layer, err := GetLayerWithBufferFromLayer(layer) if err != nil { return err } - newLayer.From = mp.GetShortTagname() - layers = append(layers, newLayer) - } - } - case "adapter": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - fp, err := filenameWithPath(path, c.Args) + layer.From = modelpath.GetShortTagname() + layers = append(layers, layer) + } + + deleteMap[manifest.Config.Digest] = struct{}{} + continue + } + defer bin.Close() + + fn(api.ProgressResponse{Status: "creating model layer"}) + ggml, err := llm.DecodeGGML(bin) if err != nil { return err } - // create a model from this specified file - fn(api.ProgressResponse{Status: "creating model layer"}) + config.ModelFormat = ggml.Name() + config.ModelFamily = ggml.ModelFamily() + config.ModelType = ggml.ModelType() + config.FileType = ggml.FileType() - file, err := os.Open(fp) + bin.Seek(0, io.SeekStart) + layer, err := CreateLayer(bin) if err != nil { - return fmt.Errorf("failed to open file: %v", err) + return err } - defer file.Close() - l, err := CreateLayer(file) + layer.MediaType = mediatype + layers = append(layers, layer) + case "adapter": + fn(api.ProgressResponse{Status: "creating adapter layer"}) + bin, err := os.Open(realpath(c.Args)) if err != nil { - return fmt.Errorf("failed to create layer: %v", err) + return err } - l.MediaType = "application/vnd.ollama.image.adapter" - layers = append(layers, l) - case "license": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) + defer bin.Close() - layer, err := CreateLayer(strings.NewReader(c.Args)) + layer, err := CreateLayer(bin) if err != nil { return err } if layer.Size > 0 { - layer.MediaType = mediaType + layer.MediaType = mediatype layers = append(layers, layer) } - case "template", "system", "prompt": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - // remove the layer if one exists - mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) - layers = removeLayerFromLayers(layers, mediaType) + case "license": + fn(api.ProgressResponse{Status: "creating license layer"}) + layer, err := CreateLayer(strings.NewReader(c.Args)) + if err != nil { + return err + } + + if layer.Size > 0 { + layer.MediaType = mediatype + layers = append(layers, layer) + } + case "template", "system": + fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)}) + + // remove duplicates layers + layers = removeLayerFromLayers(layers, mediatype) layer, err := CreateLayer(strings.NewReader(c.Args)) if err != nil { @@ -449,48 +421,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } if layer.Size > 0 { - layer.MediaType = mediaType + layer.MediaType = mediatype layers = append(layers, layer) } default: - // runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences) params[c.Name] = append(params[c.Name], c.Args) } } - // Create a single layer for the parameters if len(params) > 0 { - fn(api.ProgressResponse{Status: "creating parameter layer"}) + fn(api.ProgressResponse{Status: "creating parameters layer"}) - layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params") formattedParams, err := formatParams(params) if err != nil { - return fmt.Errorf("couldn't create params json: %v", err) + return err } - for k, v := range sourceParams { + for k, v := range fromParams { if _, ok := formattedParams[k]; !ok { formattedParams[k] = v } } if config.ModelType == "65B" { - if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { + if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 { config.ModelType = "70B" } } - bts, err := json.Marshal(formattedParams) + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(formattedParams); err != nil { + return err + } + + fn(api.ProgressResponse{Status: "creating config layer"}) + layer, err := CreateLayer(bytes.NewReader(b.Bytes())) if err != nil { return err } - l, err := CreateLayer(bytes.NewReader(bts)) - if err != nil { - return fmt.Errorf("failed to create layer: %v", err) - } - l.MediaType = "application/vnd.ollama.image.params" - layers = append(layers, l) + layer.MediaType = "application/vnd.ollama.image.params" + layers = append(layers, layer) } digests, err := getLayerDigests(layers) @@ -498,36 +469,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api return err } - var manifestLayers []*Layer - for _, l := range layers { - manifestLayers = append(manifestLayers, &l.Layer) - delete(deleteMap, l.Layer.Digest) - } - - // Create a layer for the config object - fn(api.ProgressResponse{Status: "creating config layer"}) - cfg, err := createConfigLayer(config, digests) + configLayer, err := createConfigLayer(config, digests) if err != nil { return err } - layers = append(layers, cfg) - delete(deleteMap, cfg.Layer.Digest) + + layers = append(layers, configLayer) + delete(deleteMap, configLayer.Digest) if err := SaveLayers(layers, fn, false); err != nil { return err } - // Create the manifest + var contentLayers []*Layer + for _, layer := range layers { + contentLayers = append(contentLayers, &layer.Layer) + delete(deleteMap, layer.Digest) + } + fn(api.ProgressResponse{Status: "writing manifest"}) - err = CreateManifest(name, cfg, manifestLayers) - if err != nil { + if err := CreateManifest(name, configLayer, contentLayers); err != nil { return err } - if noprune == "" { - fn(api.ProgressResponse{Status: "removing any unused layers"}) - err = deleteUnusedLayers(nil, deleteMap, false) - if err != nil { + if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if err := deleteUnusedLayers(nil, deleteMap, false); err != nil { return err } } @@ -739,7 +705,7 @@ func CopyModel(src, dest string) error { return nil } -func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error { +func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error { fp, err := GetManifestPath() if err != nil { return err @@ -779,21 +745,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry } // only delete the files which are still in the deleteMap - for k, v := range deleteMap { - if v { - fp, err := GetBlobsPath(k) - if err != nil { - log.Printf("couldn't get file path for '%s': %v", k, err) + for k := range deleteMap { + fp, err := GetBlobsPath(k) + if err != nil { + log.Printf("couldn't get file path for '%s': %v", k, err) + continue + } + if !dryRun { + if err := os.Remove(fp); err != nil { + log.Printf("couldn't remove file '%s': %v", fp, err) continue } - if !dryRun { - if err := os.Remove(fp); err != nil { - log.Printf("couldn't remove file '%s': %v", fp, err) - continue - } - } else { - log.Printf("wanted to remove: %s", fp) - } + } else { + log.Printf("wanted to remove: %s", fp) } } @@ -801,7 +765,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry } func PruneLayers() error { - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) p, err := GetBlobsPath("") if err != nil { return err @@ -818,7 +782,7 @@ func PruneLayers() error { if runtime.GOOS == "windows" { name = strings.ReplaceAll(name, "-", ":") } - deleteMap[name] = true + deleteMap[name] = struct{}{} } log.Printf("total blobs: %d", len(deleteMap)) @@ -873,11 +837,11 @@ func DeleteModel(name string) error { return err } - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) for _, layer := range manifest.Layers { - deleteMap[layer.Digest] = true + deleteMap[layer.Digest] = struct{}{} } - deleteMap[manifest.Config.Digest] = true + deleteMap[manifest.Config.Digest] = struct{}{} err = deleteUnusedLayers(&mp, deleteMap, false) if err != nil { @@ -1013,7 +977,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu var noprune string // build deleteMap to prune unused layers - deleteMap := make(map[string]bool) + deleteMap := make(map[string]struct{}) if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { manifest, _, err = GetManifest(mp) @@ -1023,9 +987,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu if manifest != nil { for _, l := range manifest.Layers { - deleteMap[l.Digest] = true + deleteMap[l.Digest] = struct{}{} } - deleteMap[manifest.Config.Digest] = true + deleteMap[manifest.Config.Digest] = struct{}{} } } diff --git a/server/routes.go b/server/routes.go index a543b10e..e53bad43 100644 --- a/server/routes.go +++ b/server/routes.go @@ -26,6 +26,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llm" + "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/version" ) @@ -414,6 +415,19 @@ func CreateModelHandler(c *gin.Context) { return } + modelfile, err := os.Open(req.Path) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + defer modelfile.Close() + + commands, err := parser.Parse(modelfile) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + ch := make(chan any) go func() { defer close(ch) @@ -424,7 +438,7 @@ func CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { + if err := CreateModel(ctx, req.Name, commands, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() From 3ca56b5adafb6a46465024df1d9a4d52a6ae4f2f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 13:45:07 -0800 Subject: [PATCH 02/11] add create modelfile field --- api/types.go | 7 ++++--- server/routes.go | 22 ++++++++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/api/types.go b/api/types.go index ffa5b7ca..2a36a1f6 100644 --- a/api/types.go +++ b/api/types.go @@ -99,9 +99,10 @@ type EmbeddingResponse struct { } type CreateRequest struct { - Name string `json:"name"` - Path string `json:"path"` - Stream *bool `json:"stream,omitempty"` + Name string `json:"name"` + Path string `json:"path"` + Modelfile string `json:"modelfile"` + Stream *bool `json:"stream,omitempty"` } type DeleteRequest struct { diff --git a/server/routes.go b/server/routes.go index e53bad43..65a96911 100644 --- a/server/routes.go +++ b/server/routes.go @@ -410,17 +410,27 @@ func CreateModelHandler(c *gin.Context) { return } - if req.Name == "" || req.Path == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name and path are required"}) + if req.Name == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"}) return } - modelfile, err := os.Open(req.Path) - if err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + if req.Path == "" && req.Modelfile == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) return } - defer modelfile.Close() + + var modelfile io.Reader = strings.NewReader(req.Modelfile) + if req.Path != "" && req.Modelfile == "" { + bin, err := os.Open(req.Path) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) + return + } + defer bin.Close() + + modelfile = bin + } commands, err := parser.Parse(modelfile) if err != nil { From 1552cee59f6080fc8b74e81317e94381d2e1844a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 14:07:40 -0800 Subject: [PATCH 03/11] client create modelfile --- api/client.go | 27 +++++++++++++++++- api/types.go | 4 +++ cmd/cmd.go | 73 +++++++++++++++++++++++++++++++++++++++--------- server/routes.go | 57 +++++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 14 deletions(-) diff --git a/api/client.go b/api/client.go index 974c08eb..262918b3 100644 --- a/api/client.go +++ b/api/client.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -95,11 +96,19 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData var reqBody io.Reader var data []byte var err error - if reqData != nil { + + switch reqData := reqData.(type) { + case io.Reader: + // reqData is already an io.Reader + reqBody = reqData + case nil: + // noop + default: data, err = json.Marshal(reqData) if err != nil { return err } + reqBody = bytes.NewReader(data) } @@ -287,3 +296,19 @@ func (c *Client) Heartbeat(ctx context.Context) error { } return nil } + +func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) { + var response CreateBlobResponse + if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil { + var statusError StatusError + if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound { + return "", err + } + + if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil { + return "", err + } + } + + return response.Path, nil +} diff --git a/api/types.go b/api/types.go index 2a36a1f6..347c4f84 100644 --- a/api/types.go +++ b/api/types.go @@ -105,6 +105,10 @@ type CreateRequest struct { Stream *bool `json:"stream,omitempty"` } +type CreateBlobResponse struct { + Path string `json:"path"` +} + type DeleteRequest struct { Name string `json:"name"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 8fc6e4c4..30c6bcf6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,9 +1,11 @@ package cmd import ( + "bytes" "context" "crypto/ed25519" "crypto/rand" + "crypto/sha256" "encoding/pem" "errors" "fmt" @@ -27,6 +29,7 @@ import ( "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/format" + "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/progressbar" "github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/server" @@ -45,17 +48,65 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - var spinner *Spinner + modelfile, err := os.ReadFile(filename) + if err != nil { + return err + } + + spinner := NewSpinner("transferring context") + go spinner.Spin(100 * time.Millisecond) + + commands, err := parser.Parse(bytes.NewReader(modelfile)) + if err != nil { + return err + } + + home, err := os.UserHomeDir() + if err != nil { + return err + } + + for _, c := range commands { + switch c.Name { + case "model", "adapter": + path := c.Args + if path == "~" { + path = home + } else if strings.HasPrefix(path, "~/") { + path = filepath.Join(home, path[2:]) + } + + bin, err := os.Open(path) + if errors.Is(err, os.ErrNotExist) && c.Name == "model" { + // value might be a model reference and not a real file + } else if err != nil { + return err + } + defer bin.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, bin); err != nil { + return err + } + bin.Seek(0, io.SeekStart) + + digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) + path, err = client.CreateBlob(cmd.Context(), digest, bin) + if err != nil { + return err + } + + modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path)) + } + } var currentDigest string var bar *progressbar.ProgressBar - request := api.CreateRequest{Name: args[0], Path: filename} + request := api.CreateRequest{Name: args[0], Path: filename, Modelfile: string(modelfile)} fn := func(resp api.ProgressResponse) error { if resp.Digest != currentDigest && resp.Digest != "" { - if spinner != nil { - spinner.Stop() - } + spinner.Stop() currentDigest = resp.Digest // pulling bar = progressbar.DefaultBytes( @@ -67,9 +118,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { bar.Set64(resp.Completed) } else { currentDigest = "" - if spinner != nil { - spinner.Stop() - } + spinner.Stop() spinner = NewSpinner(resp.Status) go spinner.Spin(100 * time.Millisecond) } @@ -81,11 +130,9 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - if spinner != nil { - spinner.Stop() - if spinner.description != "success" { - return errors.New("unexpected end to create model") - } + spinner.Stop() + if spinner.description != "success" { + return errors.New("unexpected end to create model") } return nil diff --git a/server/routes.go b/server/routes.go index 65a96911..c12a7cda 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2,6 +2,7 @@ package server import ( "context" + "crypto/sha256" "encoding/json" "errors" "fmt" @@ -649,6 +650,60 @@ func CopyModelHandler(c *gin.Context) { } } +func GetBlobHandler(c *gin.Context) { + path, err := GetBlobsPath(c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if _, err := os.Stat(path); err != nil { + c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))}) + return + } + + c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path}) +} + +func CreateBlobHandler(c *gin.Context) { + hash := sha256.New() + temp, err := os.CreateTemp("", c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + defer temp.Close() + defer os.Remove(temp.Name()) + + if _, err := io.Copy(temp, io.TeeReader(c.Request.Body, hash)); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if fmt.Sprintf("sha256:%x", hash.Sum(nil)) != c.Param("digest") { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "digest does not match body"}) + return + } + + if err := temp.Close(); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + targetPath, err := GetBlobsPath(c.Param("digest")) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := os.Rename(temp.Name(), targetPath); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath}) +} + var defaultAllowOrigins = []string{ "localhost", "127.0.0.1", @@ -708,6 +763,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { r.POST("/api/copy", CopyModelHandler) r.DELETE("/api/delete", DeleteModelHandler) r.POST("/api/show", ShowModelHandler) + r.POST("/api/blobs/:digest", CreateBlobHandler) for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) { @@ -715,6 +771,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { }) r.Handle(method, "/api/tags", ListModelsHandler) + r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler) } log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version) From a07c935d345bef1bce4f3411da0e4b69fa9bf266 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 14:27:51 -0800 Subject: [PATCH 04/11] ignore non blobs --- server/images.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/images.go b/server/images.go index 42def270..893fc4f2 100644 --- a/server/images.go +++ b/server/images.go @@ -782,7 +782,9 @@ func PruneLayers() error { if runtime.GOOS == "windows" { name = strings.ReplaceAll(name, "-", ":") } - deleteMap[name] = struct{}{} + if strings.HasPrefix(name, "sha256:") { + deleteMap[name] = struct{}{} + } } log.Printf("total blobs: %d", len(deleteMap)) From cac11c9137294961adbb08a6c279ab652af3bcdc Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 14 Nov 2023 14:44:10 -0800 Subject: [PATCH 05/11] update api docs --- docs/api.md | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/docs/api.md b/docs/api.md index 08402266..aa3ef034 100644 --- a/docs/api.md +++ b/docs/api.md @@ -4,6 +4,7 @@ - [Generate a completion](#generate-a-completion) - [Create a Model](#create-a-model) +- [Create a Blob](#create-a-blob) - [List Local Models](#list-local-models) - [Show Model Information](#show-model-information) - [Copy a Model](#copy-a-model) @@ -292,12 +293,13 @@ curl -X POST http://localhost:11434/api/generate -d '{ POST /api/create ``` -Create a model from a [`Modelfile`](./modelfile.md) +Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `modelfile` to the content of the Modelfile rather than just set `path`. This is a requirement for remote create. Remote model creation should also create any file blobs, fields such as `FROM` and `ADAPTER`, explicitly with the server using [Create a Blob](#create-a-blob) and the value to the path indicated in the response. ### Parameters - `name`: name of the model to create - `path`: path to the Modelfile +- `modelfile`: contents of the Modelfile - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects ### Examples @@ -307,7 +309,8 @@ Create a model from a [`Modelfile`](./modelfile.md) ```shell curl -X POST http://localhost:11434/api/create -d '{ "name": "mario", - "path": "~/Modelfile" + "path": "~/Modelfile", + "modelfile": "FROM llama2" }' ``` @@ -321,6 +324,32 @@ A stream of JSON objects. When finished, `status` is `success`. } ``` +## Create a Blob + +```shell +POST /api/blobs/:digest +``` + +Create a blob from a file. Returns the server file path. + +### Query Parameters + +- `digest`: the expected SHA256 digest of the file + +### Examples + +```shell +curl -X POST http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2 -d @llama2-13b-q4_0.gguf +``` + +### Response + +```json +{ + "path": "/home/user/.ollama/models/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2" +} +``` + ## List Local Models ```shell From d660eebf22c11c5b13bc990aaa4f9bf538bc5480 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 15 Nov 2023 10:57:09 -0800 Subject: [PATCH 06/11] fix create from model tag --- cmd/cmd.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 30c6bcf6..d68c5868 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -78,7 +78,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { bin, err := os.Open(path) if errors.Is(err, os.ErrNotExist) && c.Name == "model" { - // value might be a model reference and not a real file + continue } else if err != nil { return err } From 1901044b075517fc48a298d044227db9666e3904 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 15 Nov 2023 10:59:38 -0800 Subject: [PATCH 07/11] use checksum reference --- api/client.go | 13 ++++++------- cmd/cmd.go | 5 ++--- server/images.go | 9 +++++++++ server/routes.go | 5 +++-- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/api/client.go b/api/client.go index 262918b3..44af222c 100644 --- a/api/client.go +++ b/api/client.go @@ -297,18 +297,17 @@ func (c *Client) Heartbeat(ctx context.Context) error { return nil } -func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) (string, error) { - var response CreateBlobResponse - if err := c.do(ctx, http.MethodGet, fmt.Sprintf("/api/blobs/%s/path", digest), nil, &response); err != nil { +func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error { + if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil { var statusError StatusError if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound { - return "", err + return err } - if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, &response); err != nil { - return "", err + if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil { + return err } } - return response.Path, nil + return nil } diff --git a/cmd/cmd.go b/cmd/cmd.go index d68c5868..008c6b38 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -91,12 +91,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error { bin.Seek(0, io.SeekStart) digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) - path, err = client.CreateBlob(cmd.Context(), digest, bin) - if err != nil { + if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { return err } - modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte(path)) + modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest)) } } diff --git a/server/images.go b/server/images.go index 893fc4f2..47ded90c 100644 --- a/server/images.go +++ b/server/images.go @@ -287,6 +287,15 @@ func CreateModel(ctx context.Context, name string, commands []parser.Command, fn switch c.Name { case "model": + if strings.HasPrefix(c.Args, "@") { + blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) + if err != nil { + return err + } + + c.Args = blobPath + } + bin, err := os.Open(realpath(c.Args)) if err != nil { // not a file on disk so must be a model reference diff --git a/server/routes.go b/server/routes.go index c12a7cda..b3998e5c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -650,7 +650,7 @@ func CopyModelHandler(c *gin.Context) { } } -func GetBlobHandler(c *gin.Context) { +func HeadBlobHandler(c *gin.Context) { path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -771,9 +771,10 @@ func Serve(ln net.Listener, allowOrigins []string) error { }) r.Handle(method, "/api/tags", ListModelsHandler) - r.Handle(method, "/api/blobs/:digest/path", GetBlobHandler) } + r.HEAD("/api/blobs/:digest", HeadBlobHandler) + log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version) s := &http.Server{ Handler: r, From 71d71d09889e1c2cea16ecf6c20d9c30f35a12e0 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 15 Nov 2023 11:01:32 -0800 Subject: [PATCH 08/11] update docs --- docs/api.md | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/api.md b/docs/api.md index aa3ef034..9975c4e3 100644 --- a/docs/api.md +++ b/docs/api.md @@ -4,7 +4,6 @@ - [Generate a completion](#generate-a-completion) - [Create a Model](#create-a-model) -- [Create a Blob](#create-a-blob) - [List Local Models](#list-local-models) - [Show Model Information](#show-model-information) - [Copy a Model](#copy-a-model) @@ -298,7 +297,7 @@ Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `m ### Parameters - `name`: name of the model to create -- `path`: path to the Modelfile +- `path`: path to the Modelfile (deprecated: please use modelfile instead) - `modelfile`: contents of the Modelfile - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects @@ -324,7 +323,7 @@ A stream of JSON objects. When finished, `status` is `success`. } ``` -## Create a Blob +### Create a Blob ```shell POST /api/blobs/:digest @@ -332,17 +331,17 @@ POST /api/blobs/:digest Create a blob from a file. Returns the server file path. -### Query Parameters +#### Query Parameters - `digest`: the expected SHA256 digest of the file -### Examples +#### Examples ```shell curl -X POST http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2 -d @llama2-13b-q4_0.gguf ``` -### Response +#### Response ```json { From bc22d5a38b22d894468362307304e931be91c2b0 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 15 Nov 2023 13:55:37 -0800 Subject: [PATCH 09/11] no blob response --- api/types.go | 4 ---- server/routes.go | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/api/types.go b/api/types.go index 347c4f84..2a36a1f6 100644 --- a/api/types.go +++ b/api/types.go @@ -105,10 +105,6 @@ type CreateRequest struct { Stream *bool `json:"stream,omitempty"` } -type CreateBlobResponse struct { - Path string `json:"path"` -} - type DeleteRequest struct { Name string `json:"name"` } diff --git a/server/routes.go b/server/routes.go index b3998e5c..9ebe273c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -662,7 +662,7 @@ func HeadBlobHandler(c *gin.Context) { return } - c.JSON(http.StatusOK, api.CreateBlobResponse{Path: path}) + c.Status(http.StatusOK) } func CreateBlobHandler(c *gin.Context) { @@ -701,7 +701,7 @@ func CreateBlobHandler(c *gin.Context) { return } - c.JSON(http.StatusOK, api.CreateBlobResponse{Path: targetPath}) + c.Status(http.StatusCreated) } var defaultAllowOrigins = []string{ From 652d90e1c76e48a56c96c77528cffebdcc632c60 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 15 Nov 2023 15:15:36 -0800 Subject: [PATCH 10/11] Update server/images.go Co-authored-by: Bruce MacDonald --- server/images.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/images.go b/server/images.go index 47ded90c..d8ff0fd8 100644 --- a/server/images.go +++ b/server/images.go @@ -421,7 +421,7 @@ func CreateModel(ctx context.Context, name string, commands []parser.Command, fn case "template", "system": fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)}) - // remove duplicates layers + // remove duplicate layers layers = removeLayerFromLayers(layers, mediatype) layer, err := CreateLayer(strings.NewReader(c.Args)) From 54f92f01cbeb16ca2baff5720f24b08d1c543e7f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 15 Nov 2023 15:22:12 -0800 Subject: [PATCH 11/11] update docs --- docs/api.md | 36 +++++++++++++++++++++++++++++------- server/routes.go | 3 +-- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/docs/api.md b/docs/api.md index 9975c4e3..9bb4d378 100644 --- a/docs/api.md +++ b/docs/api.md @@ -323,6 +323,30 @@ A stream of JSON objects. When finished, `status` is `success`. } ``` +### Check if a Blob Exists + +```shell +HEAD /api/blobs/:digest +``` + +Check if a blob is known to the server. + +#### Query Parameters + +- `digest`: the SHA256 digest of the blob + +#### Examples + +##### Request + +```shell +curl -I http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2 +``` + +##### Response + +Return 200 OK if the blob exists, 404 Not Found if it does not. + ### Create a Blob ```shell @@ -337,17 +361,15 @@ Create a blob from a file. Returns the server file path. #### Examples +##### Request + ```shell -curl -X POST http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2 -d @llama2-13b-q4_0.gguf +curl -T model.bin -X POST http://localhost:11434/api/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2 ``` -#### Response +##### Response -```json -{ - "path": "/home/user/.ollama/models/blobs/sha256:29fdb92e57cf0827ded04ae6461b5931d01fa595843f55d36f5b275a52087dd2" -} -``` +Return 201 Created if the blob was successfully created. ## List Local Models diff --git a/server/routes.go b/server/routes.go index 9ebe273c..58145576 100644 --- a/server/routes.go +++ b/server/routes.go @@ -764,6 +764,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { r.DELETE("/api/delete", DeleteModelHandler) r.POST("/api/show", ShowModelHandler) r.POST("/api/blobs/:digest", CreateBlobHandler) + r.HEAD("/api/blobs/:digest", HeadBlobHandler) for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) { @@ -773,8 +774,6 @@ func Serve(ln net.Listener, allowOrigins []string) error { r.Handle(method, "/api/tags", ListModelsHandler) } - r.HEAD("/api/blobs/:digest", HeadBlobHandler) - log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version) s := &http.Server{ Handler: r,