Strip protocol from model path (#377)

This commit is contained in:
Ryan Baker 2023-08-21 21:56:56 -07:00 committed by GitHub
parent e3054fc74e
commit 0a892419ad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 231 additions and 43 deletions

View file

@ -97,7 +97,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
mp := server.ParseModelPath(args[0]) insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return err
}
mp, err := server.ParseModelPath(args[0], insecure)
if err != nil {
return err
}
fp, err := mp.GetManifestPath(false) fp, err := mp.GetManifestPath(false)
if err != nil { if err != nil {
return err return err
@ -106,7 +115,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
_, err = os.Stat(fp) _, err = os.Stat(fp)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := pull(args[0], false); err != nil { if err := pull(args[0], insecure); err != nil {
var apiStatusError api.StatusError var apiStatusError api.StatusError
if !errors.As(err, &apiStatusError) { if !errors.As(err, &apiStatusError) {
return err return err
@ -506,7 +515,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
case strings.HasPrefix(line, "/show"): case strings.HasPrefix(line, "/show"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {
mp := server.ParseModelPath(model) mp, err := server.ParseModelPath(model, false)
if err != nil {
return err
}
manifest, err := server.GetManifest(mp) manifest, err := server.GetManifest(mp)
if err != nil { if err != nil {
fmt.Println("error: couldn't get a manifest for this model") fmt.Println("error: couldn't get a manifest for this model")
@ -569,7 +582,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
} }
func RunServer(cmd *cobra.Command, _ []string) error { func RunServer(cmd *cobra.Command, _ []string) error {
var host, port = "127.0.0.1", "11434" host, port := "127.0.0.1", "11434"
parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":") parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
if ip := net.ParseIP(parts[0]); ip != nil { if ip := net.ParseIP(parts[0]); ip != nil {
@ -630,7 +643,7 @@ func initializeKeypair() error {
return fmt.Errorf("could not create directory %w", err) return fmt.Errorf("could not create directory %w", err)
} }
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600) err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600)
if err != nil { if err != nil {
return err return err
} }
@ -642,7 +655,7 @@ func initializeKeypair() error {
pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey()) pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey())
err = os.WriteFile(pubKeyPath, pubKeyData, 0644) err = os.WriteFile(pubKeyPath, pubKeyData, 0o644)
if err != nil { if err != nil {
return err return err
} }
@ -737,6 +750,7 @@ func NewCLI() *cobra.Command {
} }
runCmd.Flags().Bool("verbose", false, "Show timings for response") runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
serveCmd := &cobra.Command{ serveCmd := &cobra.Command{
Use: "serve", Use: "serve",

View file

@ -153,7 +153,10 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
} }
func GetModel(name string) (*Model, error) { func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name) mp, err := ParseModelPath(name, false)
if err != nil {
return nil, err
}
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
@ -272,7 +275,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) fn(api.ProgressResponse{Status: "looking for model"})
embed.model = c.Args embed.model = c.Args
mp := ParseModelPath(c.Args)
mp, err := ParseModelPath(c.Args, false)
if err != nil {
return err
}
mf, err := GetManifest(mp) mf, err := GetManifest(mp)
if err != nil { if err != nil {
modelFile, err := filenameWithPath(path, c.Args) modelFile, err := filenameWithPath(path, c.Args)
@ -286,7 +294,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err return err
} }
mf, err = GetManifest(ParseModelPath(c.Args)) mf, err = GetManifest(mp)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err) return fmt.Errorf("failed to open file after pull: %v", err)
} }
@ -674,7 +682,10 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
} }
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error { func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
mp := ParseModelPath(name) mp, err := ParseModelPath(name, false)
if err != nil {
return err
}
manifest := ManifestV2{ manifest := ManifestV2{
SchemaVersion: 2, SchemaVersion: 2,
@ -806,11 +817,22 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
} }
func CopyModel(src, dest string) error { func CopyModel(src, dest string) error {
srcPath, err := ParseModelPath(src).GetManifestPath(false) srcModelPath, err := ParseModelPath(src, false)
if err != nil { if err != nil {
return err return err
} }
destPath, err := ParseModelPath(dest).GetManifestPath(true)
srcPath, err := srcModelPath.GetManifestPath(false)
if err != nil {
return err
}
destModelPath, err := ParseModelPath(dest, false)
if err != nil {
return err
}
destPath, err := destModelPath.GetManifestPath(true)
if err != nil { if err != nil {
return err return err
} }
@ -832,7 +854,10 @@ func CopyModel(src, dest string) error {
} }
func DeleteModel(name string) error { func DeleteModel(name string) error {
mp := ParseModelPath(name) mp, err := ParseModelPath(name, false)
if err != nil {
return err
}
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
@ -859,7 +884,10 @@ func DeleteModel(name string) error {
return nil return nil
} }
tag := path[:slashIndex] + ":" + path[slashIndex+1:] tag := path[:slashIndex] + ":" + path[slashIndex+1:]
fmp := ParseModelPath(tag) fmp, err := ParseModelPath(tag, false)
if err != nil {
return err
}
// skip the manifest we're trying to delete // skip the manifest we're trying to delete
if mp.GetFullTagname() == fmp.GetFullTagname() { if mp.GetFullTagname() == fmp.GetFullTagname() {
@ -912,7 +940,10 @@ func DeleteModel(name string) error {
} }
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp, err := ParseModelPath(name, regOpts.Insecure)
if err != nil {
return err
}
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
@ -995,7 +1026,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
} }
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp, err := ParseModelPath(name, regOpts.Insecure)
if err != nil {
return err
}
fn(api.ProgressResponse{Status: "pulling manifest"}) fn(api.ProgressResponse{Status: "pulling manifest"})

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -23,42 +24,54 @@ const (
DefaultProtocolScheme = "https" DefaultProtocolScheme = "https"
) )
func ParseModelPath(name string) ModelPath { var (
slashParts := strings.Split(name, "/") ErrInvalidImageFormat = errors.New("invalid image format")
var registry, namespace, repository, tag string ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
)
func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) {
mp := ModelPath{
ProtocolScheme: DefaultProtocolScheme,
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "",
Tag: DefaultTag,
}
protocol, rest, didSplit := strings.Cut(name, "://")
if didSplit {
if protocol == "https" || protocol == "http" && allowInsecure {
mp.ProtocolScheme = protocol
name = rest
} else if protocol == "http" && !allowInsecure {
return ModelPath{}, ErrInsecureProtocol
} else {
return ModelPath{}, ErrInvalidProtocol
}
}
slashParts := strings.Split(name, "/")
switch len(slashParts) { switch len(slashParts) {
case 3: case 3:
registry = slashParts[0] mp.Registry = slashParts[0]
namespace = slashParts[1] mp.Namespace = slashParts[1]
repository = strings.Split(slashParts[2], ":")[0] mp.Repository = slashParts[2]
case 2: case 2:
registry = DefaultRegistry mp.Namespace = slashParts[0]
namespace = slashParts[0] mp.Repository = slashParts[1]
repository = strings.Split(slashParts[1], ":")[0]
case 1: case 1:
registry = DefaultRegistry mp.Repository = slashParts[0]
namespace = DefaultNamespace
repository = strings.Split(slashParts[0], ":")[0]
default: default:
fmt.Println("Invalid image format.") return ModelPath{}, ErrInvalidImageFormat
return ModelPath{}
} }
colonParts := strings.Split(slashParts[len(slashParts)-1], ":") if repo, tag, didSplit := strings.Cut(mp.Repository, ":"); didSplit {
if len(colonParts) == 2 { mp.Repository = repo
tag = colonParts[1] mp.Tag = tag
} else {
tag = DefaultTag
} }
return ModelPath{ return mp, nil
ProtocolScheme: DefaultProtocolScheme,
Registry: registry,
Namespace: namespace,
Repository: repository,
Tag: tag,
}
} }
func (mp ModelPath) GetNamespaceRepository() string { func (mp ModelPath) GetNamespaceRepository() string {

122
server/modelpath_test.go Normal file
View file

@ -0,0 +1,122 @@
package server
import "testing"
func TestParseModelPath(t *testing.T) {
type input struct {
name string
allowInsecure bool
}
tests := []struct {
name string
args input
want ModelPath
wantErr error
}{
{
"full path https",
input{"https://example.com/ns/repo:tag", false},
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
nil,
},
{
"full path http without insecure",
input{"http://example.com/ns/repo:tag", false},
ModelPath{},
ErrInsecureProtocol,
},
{
"full path http with insecure",
input{"http://example.com/ns/repo:tag", true},
ModelPath{
ProtocolScheme: "http",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
nil,
},
{
"full path invalid protocol",
input{"file://example.com/ns/repo:tag", false},
ModelPath{},
ErrInvalidProtocol,
},
{
"no protocol",
input{"example.com/ns/repo:tag", false},
ModelPath{
ProtocolScheme: "https",
Registry: "example.com",
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
nil,
},
{
"no registry",
input{"ns/repo:tag", false},
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: "ns",
Repository: "repo",
Tag: "tag",
},
nil,
},
{
"no namespace",
input{"repo:tag", false},
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: "tag",
},
nil,
},
{
"no tag",
input{"repo", false},
ModelPath{
ProtocolScheme: "https",
Registry: DefaultRegistry,
Namespace: DefaultNamespace,
Repository: "repo",
Tag: DefaultTag,
},
nil,
},
{
"invalid image format",
input{"example.com/a/b/c", false},
ModelPath{},
ErrInvalidImageFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := ParseModelPath(tc.args.name, tc.args.allowInsecure)
if err != tc.wantErr {
t.Errorf("got: %q want: %q", err, tc.wantErr)
}
if got != tc.want {
t.Errorf("got: %q want: %q", got, tc.want)
}
})
}
}

View file

@ -357,7 +357,12 @@ func ListModelsHandler(c *gin.Context) {
return nil return nil
} }
tag := path[:slashIndex] + ":" + path[slashIndex+1:] tag := path[:slashIndex] + ":" + path[slashIndex+1:]
mp := ParseModelPath(tag)
mp, err := ParseModelPath(tag, false)
if err != nil {
return err
}
manifest, err := GetManifest(mp) manifest, err := GetManifest(mp)
if err != nil { if err != nil {
log.Printf("skipping file: %s", fp) log.Printf("skipping file: %s", fp)