Merge pull request #6539 from ollama/mxyng/validate-modelpath

fix: validate modelpath
This commit is contained in:
Michael Yang 2024-08-28 14:38:27 -07:00 committed by GitHub
commit 5eb77bf976
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 26 additions and 20 deletions

View file

@ -32,6 +32,10 @@ linters:
linters-settings: linters-settings:
gci: gci:
sections: [standard, default, localmodule] sections: [standard, default, localmodule]
staticcheck:
checks:
- all
- -SA1019 # omit Deprecated check
severity: severity:
default-severity: error default-severity: error
rules: rules:

View file

@ -296,15 +296,17 @@ type EmbeddingResponse struct {
// CreateRequest is the request passed to [Client.Create]. // CreateRequest is the request passed to [Client.Create].
type CreateRequest struct { type CreateRequest struct {
Model string `json:"model"` Model string `json:"model"`
Path string `json:"path"`
Modelfile string `json:"modelfile"` Modelfile string `json:"modelfile"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Quantize string `json:"quantize,omitempty"` Quantize string `json:"quantize,omitempty"`
// Name is deprecated, see Model // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
// Quantization is deprecated, see Quantize // Deprecated: set the file content with Modelfile instead
Path string `json:"path"`
// Deprecated: use Quantize instead
Quantization string `json:"quantization,omitempty"` Quantization string `json:"quantization,omitempty"`
} }
@ -312,7 +314,7 @@ type CreateRequest struct {
type DeleteRequest struct { type DeleteRequest struct {
Model string `json:"model"` Model string `json:"model"`
// Name is deprecated, see Model // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
} }
@ -327,7 +329,7 @@ type ShowRequest struct {
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
// Name is deprecated, see Model // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
} }
@ -359,7 +361,7 @@ type PullRequest struct {
Password string `json:"password"` Password string `json:"password"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
} }
@ -380,7 +382,7 @@ type PushRequest struct {
Password string `json:"password"` Password string `json:"password"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
} }

View file

@ -73,18 +73,6 @@ func ParseModelPath(name string) ModelPath {
var errModelPathInvalid = errors.New("invalid model path") var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) Validate() error {
if mp.Repository == "" {
return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
}
if strings.Contains(mp.Tag, ":") {
return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
}
return nil
}
func (mp ModelPath) GetNamespaceRepository() string { func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository) return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
} }
@ -105,7 +93,11 @@ func (mp ModelPath) GetShortTagname() string {
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist. // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) { func (mp ModelPath) GetManifestPath() (string, error) {
return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
return filepath.Join(envconfig.Models(), "manifests", p), nil
}
return "", errModelPathInvalid
} }
func (mp ModelPath) BaseURL() *url.URL { func (mp ModelPath) BaseURL() *url.URL {

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"errors"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -154,3 +155,10 @@ func TestParseModelPath(t *testing.T) {
}) })
} }
} }
func TestInsecureModelpath(t *testing.T) {
mp := ParseModelPath("../../..:something")
if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
t.Errorf("expected error: %v", err)
}
}