fix FROM instruction erroring when referring to a file

This commit is contained in:
Jeffrey Morgan 2023-08-22 09:39:42 -07:00
parent 0a892419ad
commit a9f6c56652
5 changed files with 47 additions and 115 deletions

View file

@ -102,11 +102,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
mp, err := server.ParseModelPath(args[0], insecure) mp := server.ParseModelPath(args[0])
if err != nil { if err != nil {
return err return err
} }
if mp.ProtocolScheme == "http" && !insecure {
return fmt.Errorf("insecure protocol http")
}
fp, err := mp.GetManifestPath(false) fp, err := mp.GetManifestPath(false)
if err != nil { if err != nil {
return err return err
@ -515,7 +519,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
case strings.HasPrefix(line, "/show"): case strings.HasPrefix(line, "/show"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {
mp, err := server.ParseModelPath(model, false) mp := server.ParseModelPath(model)
if err != nil { if err != nil {
return err return err
} }

View file

@ -153,11 +153,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
} }
func GetModel(name string) (*Model, error) { func GetModel(name string) (*Model, error) {
mp, err := ParseModelPath(name, false) mp := ParseModelPath(name)
if err != nil {
return nil, err
}
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
return nil, err return nil, err
@ -276,11 +272,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
fn(api.ProgressResponse{Status: "looking for model"}) fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args embed.model = c.Args
mp, err := ParseModelPath(c.Args, false) mp := ParseModelPath(c.Args)
if err != nil {
return err
}
mf, err := GetManifest(mp) mf, err := GetManifest(mp)
if err != nil { if err != nil {
modelFile, err := filenameWithPath(path, c.Args) modelFile, err := filenameWithPath(path, c.Args)
@ -682,11 +674,7 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
} }
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error { func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
mp, err := ParseModelPath(name, false) mp := ParseModelPath(name)
if err != nil {
return err
}
manifest := ManifestV2{ manifest := ManifestV2{
SchemaVersion: 2, SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json", MediaType: "application/vnd.docker.distribution.manifest.v2+json",
@ -817,21 +805,13 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
} }
func CopyModel(src, dest string) error { func CopyModel(src, dest string) error {
srcModelPath, err := ParseModelPath(src, false) srcModelPath := ParseModelPath(src)
if err != nil {
return err
}
srcPath, err := srcModelPath.GetManifestPath(false) srcPath, err := srcModelPath.GetManifestPath(false)
if err != nil { if err != nil {
return err return err
} }
destModelPath, err := ParseModelPath(dest, false) destModelPath := ParseModelPath(dest)
if err != nil {
return err
}
destPath, err := destModelPath.GetManifestPath(true) destPath, err := destModelPath.GetManifestPath(true)
if err != nil { if err != nil {
return err return err
@ -854,11 +834,7 @@ func CopyModel(src, dest string) error {
} }
func DeleteModel(name string) error { func DeleteModel(name string) error {
mp, err := ParseModelPath(name, false) mp := ParseModelPath(name)
if err != nil {
return err
}
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
return err return err
@ -884,10 +860,7 @@ func DeleteModel(name string) error {
return nil return nil
} }
tag := path[:slashIndex] + ":" + path[slashIndex+1:] tag := path[:slashIndex] + ":" + path[slashIndex+1:]
fmp, err := ParseModelPath(tag, false) fmp := ParseModelPath(tag)
if err != nil {
return err
}
// skip the manifest we're trying to delete // skip the manifest we're trying to delete
if mp.GetFullTagname() == fmp.GetFullTagname() { if mp.GetFullTagname() == fmp.GetFullTagname() {
@ -940,13 +913,13 @@ func DeleteModel(name string) error {
} }
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp, err := ParseModelPath(name, regOpts.Insecure) mp := ParseModelPath(name)
if err != nil {
return err
}
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
}
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"}) fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
@ -1026,9 +999,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
} }
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp, err := ParseModelPath(name, regOpts.Insecure) mp := ParseModelPath(name)
if err != nil {
return err if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return fmt.Errorf("insecure protocol http")
} }
fn(api.ProgressResponse{Status: "pulling manifest"}) fn(api.ProgressResponse{Status: "pulling manifest"})

View file

@ -30,7 +30,7 @@ var (
ErrInsecureProtocol = errors.New("insecure protocol http") ErrInsecureProtocol = errors.New("insecure protocol http")
) )
func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) { func ParseModelPath(name string) ModelPath {
mp := ModelPath{ mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme, ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry, Registry: DefaultRegistry,
@ -39,39 +39,31 @@ func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) {
Tag: DefaultTag, Tag: DefaultTag,
} }
protocol, rest, didSplit := strings.Cut(name, "://") parts := strings.Split(name, "://")
if didSplit { if len(parts) > 1 {
if protocol == "https" || protocol == "http" && allowInsecure { mp.ProtocolScheme = parts[0]
mp.ProtocolScheme = protocol name = parts[1]
name = rest
} else if protocol == "http" && !allowInsecure {
return ModelPath{}, ErrInsecureProtocol
} else {
return ModelPath{}, ErrInvalidProtocol
}
} }
slashParts := strings.Split(name, "/") parts = strings.Split(name, "/")
switch len(slashParts) { switch len(parts) {
case 3: case 3:
mp.Registry = slashParts[0] mp.Registry = parts[0]
mp.Namespace = slashParts[1] mp.Namespace = parts[1]
mp.Repository = slashParts[2] mp.Repository = parts[2]
case 2: case 2:
mp.Namespace = slashParts[0] mp.Namespace = parts[0]
mp.Repository = slashParts[1] mp.Repository = parts[1]
case 1: case 1:
mp.Repository = slashParts[0] mp.Repository = parts[0]
default:
return ModelPath{}, ErrInvalidImageFormat
} }
if repo, tag, didSplit := strings.Cut(mp.Repository, ":"); didSplit { if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
mp.Repository = repo mp.Repository = repo
mp.Tag = tag mp.Tag = tag
} }
return mp, nil return mp
} }
func (mp ModelPath) GetNamespaceRepository() string { func (mp ModelPath) GetNamespaceRepository() string {

View file

@ -3,20 +3,14 @@ package server
import "testing" import "testing"
func TestParseModelPath(t *testing.T) { func TestParseModelPath(t *testing.T) {
type input struct {
name string
allowInsecure bool
}
tests := []struct { tests := []struct {
name string name string
args input arg string
want ModelPath want ModelPath
wantErr error
}{ }{
{ {
"full path https", "full path https",
input{"https://example.com/ns/repo:tag", false}, "https://example.com/ns/repo:tag",
ModelPath{ ModelPath{
ProtocolScheme: "https", ProtocolScheme: "https",
Registry: "example.com", Registry: "example.com",
@ -24,17 +18,10 @@ func TestParseModelPath(t *testing.T) {
Repository: "repo", Repository: "repo",
Tag: "tag", Tag: "tag",
}, },
nil,
}, },
{ {
"full path http without insecure", "full path http",
input{"http://example.com/ns/repo:tag", false}, "http://example.com/ns/repo:tag",
ModelPath{},
ErrInsecureProtocol,
},
{
"full path http with insecure",
input{"http://example.com/ns/repo:tag", true},
ModelPath{ ModelPath{
ProtocolScheme: "http", ProtocolScheme: "http",
Registry: "example.com", Registry: "example.com",
@ -42,17 +29,10 @@ func TestParseModelPath(t *testing.T) {
Repository: "repo", Repository: "repo",
Tag: "tag", Tag: "tag",
}, },
nil,
},
{
"full path invalid protocol",
input{"file://example.com/ns/repo:tag", false},
ModelPath{},
ErrInvalidProtocol,
}, },
{ {
"no protocol", "no protocol",
input{"example.com/ns/repo:tag", false}, "example.com/ns/repo:tag",
ModelPath{ ModelPath{
ProtocolScheme: "https", ProtocolScheme: "https",
Registry: "example.com", Registry: "example.com",
@ -60,11 +40,10 @@ func TestParseModelPath(t *testing.T) {
Repository: "repo", Repository: "repo",
Tag: "tag", Tag: "tag",
}, },
nil,
}, },
{ {
"no registry", "no registry",
input{"ns/repo:tag", false}, "ns/repo:tag",
ModelPath{ ModelPath{
ProtocolScheme: "https", ProtocolScheme: "https",
Registry: DefaultRegistry, Registry: DefaultRegistry,
@ -72,11 +51,10 @@ func TestParseModelPath(t *testing.T) {
Repository: "repo", Repository: "repo",
Tag: "tag", Tag: "tag",
}, },
nil,
}, },
{ {
"no namespace", "no namespace",
input{"repo:tag", false}, "repo:tag",
ModelPath{ ModelPath{
ProtocolScheme: "https", ProtocolScheme: "https",
Registry: DefaultRegistry, Registry: DefaultRegistry,
@ -84,11 +62,10 @@ func TestParseModelPath(t *testing.T) {
Repository: "repo", Repository: "repo",
Tag: "tag", Tag: "tag",
}, },
nil,
}, },
{ {
"no tag", "no tag",
input{"repo", false}, "repo",
ModelPath{ ModelPath{
ProtocolScheme: "https", ProtocolScheme: "https",
Registry: DefaultRegistry, Registry: DefaultRegistry,
@ -96,23 +73,12 @@ func TestParseModelPath(t *testing.T) {
Repository: "repo", Repository: "repo",
Tag: DefaultTag, Tag: DefaultTag,
}, },
nil,
},
{
"invalid image format",
input{"example.com/a/b/c", false},
ModelPath{},
ErrInvalidImageFormat,
}, },
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, err := ParseModelPath(tc.args.name, tc.args.allowInsecure) got := ParseModelPath(tc.arg)
if err != tc.wantErr {
t.Errorf("got: %q want: %q", err, tc.wantErr)
}
if got != tc.want { if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want) t.Errorf("got: %q want: %q", got, tc.want)

View file

@ -358,11 +358,7 @@ func ListModelsHandler(c *gin.Context) {
} }
tag := path[:slashIndex] + ":" + path[slashIndex+1:] tag := path[:slashIndex] + ":" + path[slashIndex+1:]
mp, err := ParseModelPath(tag, false) mp := ParseModelPath(tag)
if err != nil {
return err
}
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
log.Printf("skipping file: %s", fp) log.Printf("skipping file: %s", fp)