validate model path

This commit is contained in:
Michael Yang 2024-08-27 17:56:04 -07:00
parent 6c1c1ad6a9
commit d9d50c43cc
2 changed files with 13 additions and 13 deletions

View file

@ -73,18 +73,6 @@ func ParseModelPath(name string) ModelPath {
var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}
if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}
return nil
}
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}
@ -105,7 +93,11 @@ func (mp ModelPath) GetShortTagname() string {
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
return filepath.Join(envconfig.Models(), "manifests", p), nil
}
return "", errModelPathInvalid
}
func (mp ModelPath) BaseURL() *url.URL {

View file

@ -1,6 +1,7 @@
package server
import (
"errors"
"os"
"path/filepath"
"testing"
@ -154,3 +155,10 @@ func TestParseModelPath(t *testing.T) {
})
}
}
func TestInsecureModelpath(t *testing.T) {
mp := ParseModelPath("../../..:something")
if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
t.Errorf("expected error: %v", err)
}
}