validate the format of the digest when getting the model path (#4175)

This commit is contained in:
Patrick Devine 2024-05-05 11:46:12 -07:00 committed by GitHub
parent 026869915f
commit 2a21363bb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 84 additions and 4 deletions

View file

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strings" "strings"
) )
@ -25,9 +26,10 @@ const (
) )
var ( var (
ErrInvalidImageFormat = errors.New("invalid image format") ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidProtocol = errors.New("invalid protocol scheme") ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http") ErrInsecureProtocol = errors.New("insecure protocol http")
ErrInvalidDigestFormat = errors.New("invalid digest format")
) )
func ParseModelPath(name string) ModelPath { func ParseModelPath(name string) ModelPath {
@ -149,6 +151,17 @@ func GetBlobsPath(digest string) (string, error) {
return "", err return "", err
} }
// only accept actual sha256 digests
pattern := "^sha256[:-][0-9a-fA-F]{64}$"
re := regexp.MustCompile(pattern)
if err != nil {
return "", err
}
if digest != "" && !re.MatchString(digest) {
return "", ErrInvalidDigestFormat
}
digest = strings.ReplaceAll(digest, ":", "-") digest = strings.ReplaceAll(digest, ":", "-")
path := filepath.Join(dir, "blobs", digest) path := filepath.Join(dir, "blobs", digest)
dirPath := filepath.Dir(path) dirPath := filepath.Dir(path)

View file

@ -1,6 +1,73 @@
package server package server
import "testing" import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist
dir, err := os.MkdirTemp("", "ollama-test")
assert.Nil(t, err)
defer os.RemoveAll(dir)
tests := []struct {
name string
digest string
expected string
err error
}{
{
"empty digest",
"",
filepath.Join(dir, "blobs"),
nil,
},
{
"valid with colon",
"sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"valid with dash",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
nil,
},
{
"digest too short",
"sha256-45640291",
"",
ErrInvalidDigestFormat,
},
{
"digest too long",
"sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
"",
ErrInvalidDigestFormat,
},
{
"digest invalid chars",
"../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
"",
ErrInvalidDigestFormat,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv("OLLAMA_MODELS", dir)
got, err := GetBlobsPath(tc.digest)
assert.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name)
})
}
}
func TestParseModelPath(t *testing.T) { func TestParseModelPath(t *testing.T) {
tests := []struct { tests := []struct {