validate model path
This commit is contained in:
parent
6c1c1ad6a9
commit
d9d50c43cc
2 changed files with 13 additions and 13 deletions
|
@ -73,18 +73,6 @@ func ParseModelPath(name string) ModelPath {
|
||||||
|
|
||||||
var errModelPathInvalid = errors.New("invalid model path")
|
var errModelPathInvalid = errors.New("invalid model path")
|
||||||
|
|
||||||
func (mp ModelPath) Validate() error {
|
|
||||||
if mp.Repository == "" {
|
|
||||||
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(mp.Tag, ":") {
|
|
||||||
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mp ModelPath) GetNamespaceRepository() string {
|
func (mp ModelPath) GetNamespaceRepository() string {
|
||||||
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
|
||||||
}
|
}
|
||||||
|
@ -105,7 +93,11 @@ func (mp ModelPath) GetShortTagname() string {
|
||||||
|
|
||||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
||||||
func (mp ModelPath) GetManifestPath() (string, error) {
|
func (mp ModelPath) GetManifestPath() (string, error) {
|
||||||
return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
|
if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
|
||||||
|
return filepath.Join(envconfig.Models(), "manifests", p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", errModelPathInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mp ModelPath) BaseURL() *url.URL {
|
func (mp ModelPath) BaseURL() *url.URL {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -154,3 +155,10 @@ func TestParseModelPath(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInsecureModelpath(t *testing.T) {
|
||||||
|
mp := ParseModelPath("../../..:something")
|
||||||
|
if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
|
||||||
|
t.Errorf("expected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue