From a9f6c56652dabbf77b64e695fb9bff6f0b6de797 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Tue, 22 Aug 2023 09:39:42 -0700 Subject: [PATCH] fix `FROM` instruction erroring when referring to a file --- cmd/cmd.go | 8 ++++-- server/images.go | 58 +++++++++++----------------------------- server/modelpath.go | 38 +++++++++++--------------- server/modelpath_test.go | 52 +++++++---------------------------- server/routes.go | 6 +---- 5 files changed, 47 insertions(+), 115 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index d713a35b..09fb2e92 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -102,11 +102,15 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - mp, err := server.ParseModelPath(args[0], insecure) + mp := server.ParseModelPath(args[0]) if err != nil { return err } + if mp.ProtocolScheme == "http" && !insecure { + return fmt.Errorf("insecure protocol http") + } + fp, err := mp.GetManifestPath(false) if err != nil { return err @@ -515,7 +519,7 @@ func generateInteractive(cmd *cobra.Command, model string) error { case strings.HasPrefix(line, "/show"): args := strings.Fields(line) if len(args) > 1 { - mp, err := server.ParseModelPath(model, false) + mp := server.ParseModelPath(model) if err != nil { return err } diff --git a/server/images.go b/server/images.go index 0c0d428e..441c54ea 100644 --- a/server/images.go +++ b/server/images.go @@ -153,11 +153,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) { } func GetModel(name string) (*Model, error) { - mp, err := ParseModelPath(name, false) - if err != nil { - return nil, err - } - + mp := ParseModelPath(name) manifest, err := GetManifest(mp) if err != nil { return nil, err @@ -276,11 +272,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api fn(api.ProgressResponse{Status: "looking for model"}) embed.model = c.Args - mp, err := ParseModelPath(c.Args, false) - if err != nil { - return err - } - + mp := ParseModelPath(c.Args) mf, err := GetManifest(mp) if err != nil { modelFile, err := filenameWithPath(path, c.Args) @@ -682,11 +674,7 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force } func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error { - mp, err := ParseModelPath(name, false) - if err != nil { - return err - } - + mp := ParseModelPath(name) manifest := ManifestV2{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", @@ -817,21 +805,13 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) { } func CopyModel(src, dest string) error { - srcModelPath, err := ParseModelPath(src, false) - if err != nil { - return err - } - + srcModelPath := ParseModelPath(src) srcPath, err := srcModelPath.GetManifestPath(false) if err != nil { return err } - destModelPath, err := ParseModelPath(dest, false) - if err != nil { - return err - } - + destModelPath := ParseModelPath(dest) destPath, err := destModelPath.GetManifestPath(true) if err != nil { return err @@ -854,11 +834,7 @@ func CopyModel(src, dest string) error { } func DeleteModel(name string) error { - mp, err := ParseModelPath(name, false) - if err != nil { - return err - } - + mp := ParseModelPath(name) manifest, err := GetManifest(mp) if err != nil { return err @@ -884,10 +860,7 @@ func DeleteModel(name string) error { return nil } tag := path[:slashIndex] + ":" + path[slashIndex+1:] - fmp, err := ParseModelPath(tag, false) - if err != nil { - return err - } + fmp := ParseModelPath(tag) // skip the manifest we're trying to delete if mp.GetFullTagname() == fmp.GetFullTagname() { @@ -940,13 +913,13 @@ func DeleteModel(name string) error { } func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { - mp, err := ParseModelPath(name, regOpts.Insecure) - if err != nil { - return err - } - + mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) + if mp.ProtocolScheme == "http" && !regOpts.Insecure { + return fmt.Errorf("insecure protocol http") + } + manifest, err := GetManifest(mp) if err != nil { fn(api.ProgressResponse{Status: "couldn't retrieve manifest"}) @@ -1026,9 +999,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu } func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { - mp, err := ParseModelPath(name, regOpts.Insecure) - if err != nil { - return err + mp := ParseModelPath(name) + + if mp.ProtocolScheme == "http" && !regOpts.Insecure { + return fmt.Errorf("insecure protocol http") } fn(api.ProgressResponse{Status: "pulling manifest"}) diff --git a/server/modelpath.go b/server/modelpath.go index e331f1f6..0fe67211 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -30,7 +30,7 @@ var ( ErrInsecureProtocol = errors.New("insecure protocol http") ) -func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) { +func ParseModelPath(name string) ModelPath { mp := ModelPath{ ProtocolScheme: DefaultProtocolScheme, Registry: DefaultRegistry, @@ -39,39 +39,31 @@ func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) { Tag: DefaultTag, } - protocol, rest, didSplit := strings.Cut(name, "://") - if didSplit { - if protocol == "https" || protocol == "http" && allowInsecure { - mp.ProtocolScheme = protocol - name = rest - } else if protocol == "http" && !allowInsecure { - return ModelPath{}, ErrInsecureProtocol - } else { - return ModelPath{}, ErrInvalidProtocol - } + parts := strings.Split(name, "://") + if len(parts) > 1 { + mp.ProtocolScheme = parts[0] + name = parts[1] } - slashParts := strings.Split(name, "/") - switch len(slashParts) { + parts = strings.Split(name, "/") + switch len(parts) { case 3: - mp.Registry = slashParts[0] - mp.Namespace = slashParts[1] - mp.Repository = slashParts[2] + mp.Registry = parts[0] + mp.Namespace = parts[1] + mp.Repository = parts[2] case 2: - mp.Namespace = slashParts[0] - mp.Repository = slashParts[1] + mp.Namespace = parts[0] + mp.Repository = parts[1] case 1: - mp.Repository = slashParts[0] - default: - return ModelPath{}, ErrInvalidImageFormat + mp.Repository = parts[0] } - if repo, tag, didSplit := strings.Cut(mp.Repository, ":"); didSplit { + if repo, tag, found := strings.Cut(mp.Repository, ":"); found { mp.Repository = repo mp.Tag = tag } - return mp, nil + return mp } func (mp ModelPath) GetNamespaceRepository() string { diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 2641af90..c52c689c 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -3,20 +3,14 @@ package server import "testing" func TestParseModelPath(t *testing.T) { - type input struct { - name string - allowInsecure bool - } - tests := []struct { name string - args input + arg string want ModelPath - wantErr error }{ { "full path https", - input{"https://example.com/ns/repo:tag", false}, + "https://example.com/ns/repo:tag", ModelPath{ ProtocolScheme: "https", Registry: "example.com", @@ -24,17 +18,10 @@ func TestParseModelPath(t *testing.T) { Repository: "repo", Tag: "tag", }, - nil, }, { - "full path http without insecure", - input{"http://example.com/ns/repo:tag", false}, - ModelPath{}, - ErrInsecureProtocol, - }, - { - "full path http with insecure", - input{"http://example.com/ns/repo:tag", true}, + "full path http", + "http://example.com/ns/repo:tag", ModelPath{ ProtocolScheme: "http", Registry: "example.com", @@ -42,17 +29,10 @@ func TestParseModelPath(t *testing.T) { Repository: "repo", Tag: "tag", }, - nil, - }, - { - "full path invalid protocol", - input{"file://example.com/ns/repo:tag", false}, - ModelPath{}, - ErrInvalidProtocol, }, { "no protocol", - input{"example.com/ns/repo:tag", false}, + "example.com/ns/repo:tag", ModelPath{ ProtocolScheme: "https", Registry: "example.com", @@ -60,11 +40,10 @@ func TestParseModelPath(t *testing.T) { Repository: "repo", Tag: "tag", }, - nil, }, { "no registry", - input{"ns/repo:tag", false}, + "ns/repo:tag", ModelPath{ ProtocolScheme: "https", Registry: DefaultRegistry, @@ -72,11 +51,10 @@ func TestParseModelPath(t *testing.T) { Repository: "repo", Tag: "tag", }, - nil, }, { "no namespace", - input{"repo:tag", false}, + "repo:tag", ModelPath{ ProtocolScheme: "https", Registry: DefaultRegistry, @@ -84,11 +62,10 @@ func TestParseModelPath(t *testing.T) { Repository: "repo", Tag: "tag", }, - nil, }, { "no tag", - input{"repo", false}, + "repo", ModelPath{ ProtocolScheme: "https", Registry: DefaultRegistry, @@ -96,23 +73,12 @@ func TestParseModelPath(t *testing.T) { Repository: "repo", Tag: DefaultTag, }, - nil, - }, - { - "invalid image format", - input{"example.com/a/b/c", false}, - ModelPath{}, - ErrInvalidImageFormat, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := ParseModelPath(tc.args.name, tc.args.allowInsecure) - - if err != tc.wantErr { - t.Errorf("got: %q want: %q", err, tc.wantErr) - } + got := ParseModelPath(tc.arg) if got != tc.want { t.Errorf("got: %q want: %q", got, tc.want) diff --git a/server/routes.go b/server/routes.go index d0dc3d32..880eba8a 100644 --- a/server/routes.go +++ b/server/routes.go @@ -358,11 +358,7 @@ func ListModelsHandler(c *gin.Context) { } tag := path[:slashIndex] + ":" + path[slashIndex+1:] - mp, err := ParseModelPath(tag, false) - if err != nil { - return err - } - + mp := ParseModelPath(tag) manifest, err := GetManifest(mp) if err != nil { log.Printf("skipping file: %s", fp)