types/model: restrict digest hash part to a minimum of 2 characters (#3858)

This allows users of a valid Digest to know it has a minimum of 2
characters in the hash part for use when sharding.

This is a reasonable restriction because the hash part is usually a
SHA256 hash, which is 64 characters long. There is no anticipation of
using a hash with fewer than 2 characters.

Also, add MustParseDigest.

Also, replace Digest.Type with Digest.Split for getting both the type
and hash parts together, which is the most common case when asking for
either.
This commit is contained in:
Blake Mizerany 2024-04-23 18:24:17 -07:00 committed by GitHub
parent 16b52331a4
commit 4dc4f1be34
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 19 deletions

View file

@ -15,14 +15,10 @@ type Digest struct {
s string s string
} }
// Type returns the digest type of the digest. // Split returns the digest type and the digest value.
// func (d Digest) Split() (typ, digest string) {
// Example: typ, digest, _ = strings.Cut(d.s, "-")
// return
// ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
typ, _, _ := strings.Cut(d.s, "-")
return typ
} }
// String returns the digest in the form of "<digest-type>-<digest>", or the // String returns the digest in the form of "<digest-type>-<digest>", or the
@ -51,12 +47,20 @@ func ParseDigest(s string) Digest {
if !ok { if !ok {
typ, digest, ok = strings.Cut(s, ":") typ, digest, ok = strings.Cut(s, ":")
} }
if ok && isValidDigestType(typ) && isValidHex(digest) { if ok && isValidDigestType(typ) && isValidHex(digest) && len(digest) >= 2 {
return Digest{s: fmt.Sprintf("%s-%s", typ, digest)} return Digest{s: fmt.Sprintf("%s-%s", typ, digest)}
} }
return Digest{} return Digest{}
} }
func MustParseDigest(s string) Digest {
d := ParseDigest(s)
if !d.IsValid() {
panic(fmt.Sprintf("invalid digest: %q", s))
}
return d
}
func isValidDigestType(s string) bool { func isValidDigestType(s string) bool {
if len(s) == 0 { if len(s) == 0 {
return false return false

View file

@ -7,6 +7,7 @@ import (
"hash/maphash" "hash/maphash"
"io" "io"
"log/slog" "log/slog"
"path"
"path/filepath" "path/filepath"
"slices" "slices"
"strings" "strings"
@ -589,10 +590,20 @@ func ParseNameFromURLPath(s, fill string) Name {
// Example: // Example:
// //
// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag" // ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag"
func (r Name) URLPath() string { func (r Name) DisplayURLPath() string {
return r.DisplayShortest(MaskNothing) return r.DisplayShortest(MaskNothing)
} }
// URLPath returns a complete, canonicalized, relative URL path using the parts of a
// complete Name in the form:
//
// <host>/<namespace>/<model>/<tag>
//
// The parts are downcased.
func (r Name) URLPath() string {
return strings.ToLower(path.Join(r.parts[:PartBuild]...))
}
// ParseNameFromFilepath parses a file path into a Name. The input string must be a // ParseNameFromFilepath parses a file path into a Name. The input string must be a
// valid file path representation of a model name in the form: // valid file path representation of a model name in the form:
// //

View file

@ -50,10 +50,10 @@ var testNames = map[string]fields{
"mistral:latest@": {}, "mistral:latest@": {},
// resolved // resolved
"x@sha123-1": {model: "x", digest: "sha123-1"}, "x@sha123-12": {model: "x", digest: "sha123-12"},
"@sha456-2": {digest: "sha456-2"}, "@sha456-22": {digest: "sha456-22"},
"@sha456-1": {},
"@@sha123-1": {}, "@@sha123-22": {},
// preserves case for build // preserves case for build
"x+b": {model: "x", build: "b"}, "x+b": {model: "x", build: "b"},
@ -485,7 +485,7 @@ func TestNamePath(t *testing.T) {
t.Run(tt.in, func(t *testing.T) { t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in, FillNothing) p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p) t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.URLPath(); g != tt.want { if g := p.DisplayURLPath(); g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want) t.Errorf("got = %q; want %q", g, tt.want)
} }
}) })
@ -678,18 +678,18 @@ func ExampleName_CompareFold_sort() {
func ExampleName_completeAndResolved() { func ExampleName_completeAndResolved() {
for _, s := range []string{ for _, s := range []string{
"x/y/z:latest+q4_0@sha123-1", "x/y/z:latest+q4_0@sha123-abc",
"x/y/z:latest+q4_0", "x/y/z:latest+q4_0",
"@sha123-1", "@sha123-abc",
} { } {
name := ParseName(s, FillNothing) name := ParseName(s, FillNothing)
fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest()) fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
} }
// Output: // Output:
// complete:true resolved:true digest:sha123-1 // complete:true resolved:true digest:sha123-abc
// complete:true resolved:false digest: // complete:true resolved:false digest:
// complete:false resolved:true digest:sha123-1 // complete:false resolved:true digest:sha123-abc
} }
func ExampleName_DisplayShortest() { func ExampleName_DisplayShortest() {