Strip protocol from model path (#377)
This commit is contained in:
parent
e3054fc74e
commit
0a892419ad
5 changed files with 231 additions and 43 deletions
26
cmd/cmd.go
26
cmd/cmd.go
|
@ -97,7 +97,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
mp := server.ParseModelPath(args[0])
|
||||
insecure, err := cmd.Flags().GetBool("insecure")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp, err := server.ParseModelPath(args[0], insecure)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fp, err := mp.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -106,7 +115,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||
_, err = os.Stat(fp)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := pull(args[0], false); err != nil {
|
||||
if err := pull(args[0], insecure); err != nil {
|
||||
var apiStatusError api.StatusError
|
||||
if !errors.As(err, &apiStatusError) {
|
||||
return err
|
||||
|
@ -506,7 +515,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
|
|||
case strings.HasPrefix(line, "/show"):
|
||||
args := strings.Fields(line)
|
||||
if len(args) > 1 {
|
||||
mp := server.ParseModelPath(model)
|
||||
mp, err := server.ParseModelPath(model, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifest, err := server.GetManifest(mp)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get a manifest for this model")
|
||||
|
@ -569,7 +582,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
|
|||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
var host, port = "127.0.0.1", "11434"
|
||||
host, port := "127.0.0.1", "11434"
|
||||
|
||||
parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
|
||||
if ip := net.ParseIP(parts[0]); ip != nil {
|
||||
|
@ -630,7 +643,7 @@ func initializeKeypair() error {
|
|||
return fmt.Errorf("could not create directory %w", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600)
|
||||
err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -642,7 +655,7 @@ func initializeKeypair() error {
|
|||
|
||||
pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey())
|
||||
|
||||
err = os.WriteFile(pubKeyPath, pubKeyData, 0644)
|
||||
err = os.WriteFile(pubKeyPath, pubKeyData, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -737,6 +750,7 @@ func NewCLI() *cobra.Command {
|
|||
}
|
||||
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
serveCmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
|
|
|
@ -153,7 +153,10 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
|
|||
}
|
||||
|
||||
func GetModel(name string) (*Model, error) {
|
||||
mp := ParseModelPath(name)
|
||||
mp, err := ParseModelPath(name, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
|
@ -272,7 +275,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
|||
case "model":
|
||||
fn(api.ProgressResponse{Status: "looking for model"})
|
||||
embed.model = c.Args
|
||||
mp := ParseModelPath(c.Args)
|
||||
|
||||
mp, err := ParseModelPath(c.Args, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mf, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
modelFile, err := filenameWithPath(path, c.Args)
|
||||
|
@ -286,7 +294,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
|||
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
mf, err = GetManifest(ParseModelPath(c.Args))
|
||||
mf, err = GetManifest(mp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file after pull: %v", err)
|
||||
}
|
||||
|
@ -674,7 +682,10 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
|
|||
}
|
||||
|
||||
func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
|
||||
mp := ParseModelPath(name)
|
||||
mp, err := ParseModelPath(name, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifest := ManifestV2{
|
||||
SchemaVersion: 2,
|
||||
|
@ -806,11 +817,22 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
|||
}
|
||||
|
||||
func CopyModel(src, dest string) error {
|
||||
srcPath, err := ParseModelPath(src).GetManifestPath(false)
|
||||
srcModelPath, err := ParseModelPath(src, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
destPath, err := ParseModelPath(dest).GetManifestPath(true)
|
||||
|
||||
srcPath, err := srcModelPath.GetManifestPath(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destModelPath, err := ParseModelPath(dest, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
destPath, err := destModelPath.GetManifestPath(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -832,7 +854,10 @@ func CopyModel(src, dest string) error {
|
|||
}
|
||||
|
||||
func DeleteModel(name string) error {
|
||||
mp := ParseModelPath(name)
|
||||
mp, err := ParseModelPath(name, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
|
@ -859,7 +884,10 @@ func DeleteModel(name string) error {
|
|||
return nil
|
||||
}
|
||||
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
|
||||
fmp := ParseModelPath(tag)
|
||||
fmp, err := ParseModelPath(tag, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// skip the manifest we're trying to delete
|
||||
if mp.GetFullTagname() == fmp.GetFullTagname() {
|
||||
|
@ -912,7 +940,10 @@ func DeleteModel(name string) error {
|
|||
}
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
mp, err := ParseModelPath(name, regOpts.Insecure)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
|
@ -995,7 +1026,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
|||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
mp, err := ParseModelPath(name, regOpts.Insecure)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "pulling manifest"})
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -23,42 +24,54 @@ const (
|
|||
DefaultProtocolScheme = "https"
|
||||
)
|
||||
|
||||
func ParseModelPath(name string) ModelPath {
|
||||
slashParts := strings.Split(name, "/")
|
||||
var registry, namespace, repository, tag string
|
||||
var (
|
||||
ErrInvalidImageFormat = errors.New("invalid image format")
|
||||
ErrInvalidProtocol = errors.New("invalid protocol scheme")
|
||||
ErrInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) {
|
||||
mp := ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "",
|
||||
Tag: DefaultTag,
|
||||
}
|
||||
|
||||
protocol, rest, didSplit := strings.Cut(name, "://")
|
||||
if didSplit {
|
||||
if protocol == "https" || protocol == "http" && allowInsecure {
|
||||
mp.ProtocolScheme = protocol
|
||||
name = rest
|
||||
} else if protocol == "http" && !allowInsecure {
|
||||
return ModelPath{}, ErrInsecureProtocol
|
||||
} else {
|
||||
return ModelPath{}, ErrInvalidProtocol
|
||||
}
|
||||
}
|
||||
|
||||
slashParts := strings.Split(name, "/")
|
||||
switch len(slashParts) {
|
||||
case 3:
|
||||
registry = slashParts[0]
|
||||
namespace = slashParts[1]
|
||||
repository = strings.Split(slashParts[2], ":")[0]
|
||||
mp.Registry = slashParts[0]
|
||||
mp.Namespace = slashParts[1]
|
||||
mp.Repository = slashParts[2]
|
||||
case 2:
|
||||
registry = DefaultRegistry
|
||||
namespace = slashParts[0]
|
||||
repository = strings.Split(slashParts[1], ":")[0]
|
||||
mp.Namespace = slashParts[0]
|
||||
mp.Repository = slashParts[1]
|
||||
case 1:
|
||||
registry = DefaultRegistry
|
||||
namespace = DefaultNamespace
|
||||
repository = strings.Split(slashParts[0], ":")[0]
|
||||
mp.Repository = slashParts[0]
|
||||
default:
|
||||
fmt.Println("Invalid image format.")
|
||||
return ModelPath{}
|
||||
return ModelPath{}, ErrInvalidImageFormat
|
||||
}
|
||||
|
||||
colonParts := strings.Split(slashParts[len(slashParts)-1], ":")
|
||||
if len(colonParts) == 2 {
|
||||
tag = colonParts[1]
|
||||
} else {
|
||||
tag = DefaultTag
|
||||
if repo, tag, didSplit := strings.Cut(mp.Repository, ":"); didSplit {
|
||||
mp.Repository = repo
|
||||
mp.Tag = tag
|
||||
}
|
||||
|
||||
return ModelPath{
|
||||
ProtocolScheme: DefaultProtocolScheme,
|
||||
Registry: registry,
|
||||
Namespace: namespace,
|
||||
Repository: repository,
|
||||
Tag: tag,
|
||||
}
|
||||
return mp, nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) GetNamespaceRepository() string {
|
||||
|
|
122
server/modelpath_test.go
Normal file
122
server/modelpath_test.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseModelPath(t *testing.T) {
|
||||
type input struct {
|
||||
name string
|
||||
allowInsecure bool
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args input
|
||||
want ModelPath
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"full path https",
|
||||
input{"https://example.com/ns/repo:tag", false},
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"full path http without insecure",
|
||||
input{"http://example.com/ns/repo:tag", false},
|
||||
ModelPath{},
|
||||
ErrInsecureProtocol,
|
||||
},
|
||||
{
|
||||
"full path http with insecure",
|
||||
input{"http://example.com/ns/repo:tag", true},
|
||||
ModelPath{
|
||||
ProtocolScheme: "http",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"full path invalid protocol",
|
||||
input{"file://example.com/ns/repo:tag", false},
|
||||
ModelPath{},
|
||||
ErrInvalidProtocol,
|
||||
},
|
||||
{
|
||||
"no protocol",
|
||||
input{"example.com/ns/repo:tag", false},
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: "example.com",
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"no registry",
|
||||
input{"ns/repo:tag", false},
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: "ns",
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"no namespace",
|
||||
input{"repo:tag", false},
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: "tag",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"no tag",
|
||||
input{"repo", false},
|
||||
ModelPath{
|
||||
ProtocolScheme: "https",
|
||||
Registry: DefaultRegistry,
|
||||
Namespace: DefaultNamespace,
|
||||
Repository: "repo",
|
||||
Tag: DefaultTag,
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"invalid image format",
|
||||
input{"example.com/a/b/c", false},
|
||||
ModelPath{},
|
||||
ErrInvalidImageFormat,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := ParseModelPath(tc.args.name, tc.args.allowInsecure)
|
||||
|
||||
if err != tc.wantErr {
|
||||
t.Errorf("got: %q want: %q", err, tc.wantErr)
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("got: %q want: %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -357,7 +357,12 @@ func ListModelsHandler(c *gin.Context) {
|
|||
return nil
|
||||
}
|
||||
tag := path[:slashIndex] + ":" + path[slashIndex+1:]
|
||||
mp := ParseModelPath(tag)
|
||||
|
||||
mp, err := ParseModelPath(tag, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manifest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
log.Printf("skipping file: %s", fp)
|
||||
|
|
Loading…
Reference in a new issue