diff --git a/types/model/name.go b/types/model/name.go index ec129be5..a0a44703 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -4,6 +4,7 @@ package model import ( "cmp" + "encoding/hex" "errors" "fmt" "log/slog" @@ -308,3 +309,57 @@ func cutPromised(s, sep string) (before, after string, ok bool) { } return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true } + +type DigestType byte + +const ( + DigestTypeInvalid DigestType = iota + DigestTypeSHA256 +) + +func (t DigestType) String() string { + switch t { + case DigestTypeSHA256: + return "sha256" + default: + return "invalid" + } +} + +type Digest struct { + Type DigestType + Sum [32]byte +} + +func ParseDigest(s string) (Digest, error) { + i := strings.IndexAny(s, "-:") + if i < 0 { + return Digest{}, fmt.Errorf("invalid digest %q", s) + } + typ, encSum := s[:i], s[i+1:] + if typ != "sha256" { + return Digest{}, fmt.Errorf("unsupported digest type %q", typ) + } + d := Digest{ + Type: DigestTypeSHA256, + } + n, err := hex.Decode(d.Sum[:], []byte(encSum)) + if err != nil { + return Digest{}, err + } + if n != 32 { + return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n) + } + return d, nil +} + +func (d Digest) String() string { + if d.Type == DigestTypeInvalid { + return "" + } + return fmt.Sprintf("sha256-%x", d.Sum) +} + +func (d Digest) IsValid() bool { + return d.Type != DigestTypeInvalid +} diff --git a/types/model/name_test.go b/types/model/name_test.go index f1416fd2..75659f0d 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -232,6 +232,40 @@ func TestFilepathAllocs(t *testing.T) { } } +const ( + validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000" + validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000" +) + +func TestParseDigest(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"", ""}, // empty + {"sha123-12", ""}, // invalid type + {"sha256-", ""}, // invalid sum + {"sha256-123", ""}, // invalid odd length sum + + {validSha256, validSha256}, + {validSha256Old, validSha256}, + } + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + got, err := ParseDigest(tt.in) + if err != nil { + if tt.want != "" { + t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want) + } + return + } + if got.String() != tt.want { + t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want) + } + }) + } +} + func FuzzName(f *testing.F) { for s := range testCases { f.Add(s)