diff --git a/types/model/name.go b/types/model/name.go index c7822f08..906b3152 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -3,6 +3,7 @@ package model import ( "cmp" "errors" + "fmt" "hash/maphash" "io" "log/slog" @@ -25,11 +26,17 @@ var ( // Defaults const ( - // DefaultMask is the default mask used by [Name.DisplayShortest]. - DefaultMask = "registry.ollama.ai/library/_:latest" + // MaskDefault is the default mask used by [Name.DisplayShortest]. + MaskDefault = "registry.ollama.ai/library/?:latest" + + // MaskNothing is a mask that masks nothing. + MaskNothing = "?/?/?:?" // DefaultFill is the default fill used by [ParseName]. - DefaultFill = "registry.ollama.ai/library/_:latest" + FillDefault = "registry.ollama.ai/library/?:latest+Q4_0" + + // FillNothing is a fill that fills nothing. + FillNothing = "?/?/?:?+?" ) const MaxNamePartLen = 128 @@ -47,11 +54,7 @@ const ( PartBuild PartDigest - // Invalid is a special part that is used to indicate that a part is - // invalid. It is not a valid part of a Name. - // - // It should be kept as the last part in the list. - PartInvalid + PartExtraneous = -1 ) var kindNames = map[PartKind]string{ @@ -61,7 +64,6 @@ var kindNames = map[PartKind]string{ PartTag: "Tag", PartBuild: "Build", PartDigest: "Digest", - PartInvalid: "Invalid", } func (k PartKind) String() string { @@ -96,8 +98,6 @@ func (k PartKind) String() string { // The parts can be obtained in their original form by calling [Name.Parts]. // // To check if a Name has at minimum a valid model part, use [Name.IsValid]. -// -// To make a Name by filling in missing parts from another Name, use [Fill]. type Name struct { _ structs.Incomparable parts [6]string // host, namespace, model, tag, build, digest @@ -109,7 +109,7 @@ type Name struct { // and mean zero allocations for String. } -// ParseNameFill parses s into a Name, and returns the result of filling it with +// ParseName parses s into a Name, and returns the result of filling it with // defaults. The input string must be a valid string // representation of a model name in the form: // @@ -139,19 +139,19 @@ type Name struct { // // It returns the zero value if any part is invalid. // -// As a rule of thumb, an valid name is one that can be round-tripped with -// the [Name.String] method. That means ("x+") is invalid because -// [Name.String] will not print a "+" if the build is empty. +// # Fills // -// For more about filling in missing parts, see [Fill]. -func ParseNameFill(s, defaults string) Name { +// For any valid s, the fill string is used to fill in missing parts of the +// Name. The fill string must be a valid Name with the exception that any part +// may be the string ("?"), which will not be considered for filling. +func ParseName(s, fill string) Name { var r Name parts(s)(func(kind PartKind, part string) bool { - if kind == PartInvalid { + if kind == PartDigest && !ParseDigest(part).IsValid() { r = Name{} return false } - if kind == PartDigest && !ParseDigest(part).IsValid() { + if kind == PartExtraneous || !isValidPart(kind, part) { r = Name{} return false } @@ -159,34 +159,48 @@ func ParseNameFill(s, defaults string) Name { return true }) if r.IsValid() || r.IsResolved() { - if defaults == "" { - return r - } - return Fill(r, ParseNameFill(defaults, "")) + fill = cmp.Or(fill, FillDefault) + return fillName(r, fill) } return Name{} } -// ParseName is equal to ParseNameFill(s, DefaultFill). -func ParseName(s string) Name { - return ParseNameFill(s, DefaultFill) +func parseMask(s string) Name { + var r Name + parts(s)(func(kind PartKind, part string) bool { + if part == "?" { + // mask part; treat as empty but valid + return true + } + if !isValidPart(kind, part) { + panic(fmt.Errorf("invalid mask part %s: %q", kind, part)) + } + r.parts[kind] = part + return true + }) + return r } -func MustParseNameFill(s, defaults string) Name { - r := ParseNameFill(s, "") +func MustParseName(s, defaults string) Name { + r := ParseName(s, "") if !r.IsValid() { - panic("model.MustParseName: invalid name: " + s) + panic("invalid Name: " + s) } return r } -// Fill fills in the missing parts of dst with the parts of src. +// fillName fills in the missing parts of dst with the parts of src. // // The returned Name will only be valid if dst is valid. -func Fill(dst, src Name) Name { - var r Name +// +// It skipps fill parts that are "?". +func fillName(r Name, fill string) Name { + f := parseMask(fill) for i := range r.parts { - r.parts[i] = cmp.Or(dst.parts[i], src.parts[i]) + if f.parts[i] == "?" { + continue + } + r.parts[i] = cmp.Or(r.parts[i], f.parts[i]) } return r } @@ -231,30 +245,58 @@ func (r Name) slice(from, to PartKind) Name { return v } -// DisplayShortest returns the shortest possible display string in form: +// DisplayShortest returns the shortest possible, masked display string in form: // // [host/][/][:] // -// The host is omitted if it is the mask host is the same as r. -// The namespace is omitted if the host and the namespace are the same as r. -// The tag is omitted if it is the mask tag is the same as r. +// # Masks +// +// The mask is a string that specifies which parts of the name to omit based +// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name +// that are the same as the mask, moving from left to right until the first +// unequal part is found. It then moves right to left until the first unequal +// part is found. The result is the shortest possible display string. +// +// Unlike a [Name] the mask can contain "?" characters which are treated as +// wildcards. A "?" will never match a part of the name, since a valid name +// can never contain a "?" character. +// +// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked +// with ("registry.ollama.ai/library/?:latest") will produce the display string +// ("mistral"). +// +// If mask is the empty string, then [MaskDefault] is used. +// +// # Safety +// +// To avoid unsafe behavior, DisplayShortest will panic if r is the zero +// value to prevent the returns of a "" string. Callers should consult +// [Name.IsValid] before calling this method. +// +// # Builds +// +// For now, DisplayShortest does consider the build or return one in the +// result. We can lift this restriction when needed. func (r Name) DisplayShortest(mask string) string { - mask = cmp.Or(mask, DefaultMask) - d := ParseName(mask) - if !d.IsValid() { - panic("mask is an invalid Name") + mask = cmp.Or(mask, MaskDefault) + d := parseMask(mask) + if d.IsZero() { + panic(fmt.Errorf("invalid mask %q", mask)) } - equalSlice := func(form, to PartKind) bool { - return r.slice(form, to).EqualFold(d.slice(form, to)) + if r.IsZero() { + panic("invalid Name") } - if equalSlice(PartHost, PartNamespace) { - r.parts[PartNamespace] = "" + for i := range PartTag { + if !strings.EqualFold(r.parts[i], d.parts[i]) { + break + } + r.parts[i] = "" } - if equalSlice(PartHost, PartHost) { - r.parts[PartHost] = "" - } - if equalSlice(PartTag, PartTag) { - r.parts[PartTag] = "" + for i := PartTag; i >= 0; i-- { + if !strings.EqualFold(r.parts[i], d.parts[i]) { + break + } + r.parts[i] = "" } return r.slice(PartHost, PartTag).String() } @@ -418,27 +460,16 @@ type iter_Seq2[A, B any] func(func(A, B) bool) // No other normalizations are performed. func parts(s string) iter_Seq2[PartKind, string] { return func(yield func(PartKind, string) bool) { - //nolint:gosimple if strings.HasPrefix(s, "http://") { s = s[len("http://"):] - } - //nolint:gosimple - if strings.HasPrefix(s, "https://") { - s = s[len("https://"):] + } else { + s = strings.TrimPrefix(s, "https://") } if len(s) > MaxNamePartLen || len(s) == 0 { return } - yieldValid := func(kind PartKind, part string) bool { - if !isValidPart(kind, part) { - yield(PartInvalid, "") - return false - } - return yield(kind, part) - } - numConsecutiveDots := 0 partLen := 0 state, j := PartDigest, len(s) @@ -448,7 +479,7 @@ func parts(s string) iter_Seq2[PartKind, string] { // we don't keep spinning on it, waiting for // an isInValidPart check which would scan // over it again. - yield(PartInvalid, "") + yield(state, s[i+1:j]) return } @@ -456,7 +487,7 @@ func parts(s string) iter_Seq2[PartKind, string] { case '@': switch state { case PartDigest: - if !yieldValid(PartDigest, s[i+1:j]) { + if !yield(PartDigest, s[i+1:j]) { return } if i == 0 { @@ -468,67 +499,63 @@ func parts(s string) iter_Seq2[PartKind, string] { } state, j, partLen = PartBuild, i, 0 default: - yield(PartInvalid, "") + yield(PartExtraneous, s[i+1:j]) return } case '+': switch state { case PartBuild, PartDigest: - if !yieldValid(PartBuild, s[i+1:j]) { + if !yield(PartBuild, s[i+1:j]) { return } state, j, partLen = PartTag, i, 0 default: - yield(PartInvalid, "") + yield(PartExtraneous, s[i+1:j]) return } case ':': switch state { case PartTag, PartBuild, PartDigest: - if !yieldValid(PartTag, s[i+1:j]) { + if !yield(PartTag, s[i+1:j]) { return } state, j, partLen = PartModel, i, 0 default: - yield(PartInvalid, "") + yield(PartExtraneous, s[i+1:j]) return } case '/': switch state { case PartModel, PartTag, PartBuild, PartDigest: - if !yieldValid(PartModel, s[i+1:j]) { + if !yield(PartModel, s[i+1:j]) { return } state, j = PartNamespace, i case PartNamespace: - if !yieldValid(PartNamespace, s[i+1:j]) { + if !yield(PartNamespace, s[i+1:j]) { return } state, j, partLen = PartHost, i, 0 default: - yield(PartInvalid, "") + yield(PartExtraneous, s[i+1:j]) return } default: if s[i] == '.' { if numConsecutiveDots++; numConsecutiveDots > 1 { - yield(PartInvalid, "") + yield(state, "") return } } else { numConsecutiveDots = 0 } - if !isValidByteFor(state, s[i]) { - yield(PartInvalid, "") - return - } } } if state <= PartNamespace { - yieldValid(state, s[:j]) + yield(state, s[:j]) } else { - yieldValid(PartModel, s[:j]) + yield(PartModel, s[:j]) } } } diff --git a/types/model/name_test.go b/types/model/name_test.go index 78b57fca..14d36b64 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -111,11 +111,11 @@ func TestNameConsecutiveDots(t *testing.T) { for i := 1; i < 10; i++ { s := strings.Repeat(".", i) if i > 1 { - if g := ParseNameFill(s, "").String(); g != "" { + if g := ParseName(s, FillNothing).String(); g != "" { t.Errorf("ParseName(%q) = %q; want empty string", s, g) } } else { - if g := ParseNameFill(s, "").String(); g != s { + if g := ParseName(s, FillNothing).String(); g != s { t.Errorf("ParseName(%q) = %q; want %q", s, g, s) } } @@ -148,14 +148,14 @@ func TestParseName(t *testing.T) { s := prefix + baseName t.Run(s, func(t *testing.T) { - name := ParseNameFill(s, "") + name := ParseName(s, FillNothing) got := fieldsFromName(name) if got != want { t.Errorf("ParseName(%q) = %q; want %q", s, got, want) } // test round-trip - if !ParseNameFill(name.String(), "").EqualFold(name) { + if !ParseName(name.String(), FillNothing).EqualFold(name) { t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName) } }) @@ -163,6 +163,47 @@ func TestParseName(t *testing.T) { } } +func TestParseNameFill(t *testing.T) { + cases := []struct { + in string + fill string + want string + }{ + {"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"}, + {"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"}, + {"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"}, + + // Invalid + {"", "example.com/library/?:latest+Q4_0", ""}, + {"llama2:?", "example.com/library/?:latest+Q4_0", ""}, + } + + for _, tt := range cases { + t.Run(tt.in, func(t *testing.T) { + name := ParseName(tt.in, tt.fill) + if g := name.String(); g != tt.want { + t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want) + } + }) + } +} + +func TestParseNameHTTPDoublePrefixStrip(t *testing.T) { + cases := []string{ + "http://https://valid.com/valid/valid:latest", + "https://http://valid.com/valid/valid:latest", + } + for _, s := range cases { + t.Run(s, func(t *testing.T) { + name := ParseName(s, FillNothing) + if name.IsValid() { + t.Errorf("expected invalid path; got %#v", name) + } + }) + } + +} + func TestCompleteWithAndWithoutBuild(t *testing.T) { cases := []struct { in string @@ -179,7 +220,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) { for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { - p := ParseNameFill(tt.in, "") + p := ParseName(tt.in, FillNothing) t.Logf("ParseName(%q) = %#v", tt.in, p) if g := p.IsComplete(); g != tt.complete { t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete) @@ -194,7 +235,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) { // inlined when used in Complete, preventing any allocations or // escaping to the heap. allocs := testing.AllocsPerRun(1000, func() { - keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete()) + keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete()) }) if allocs > 0 { t.Errorf("Complete allocs = %v; want 0", allocs) @@ -211,7 +252,7 @@ func TestNameLogValue(t *testing.T) { t.Run(s, func(t *testing.T) { var b bytes.Buffer log := slog.New(slog.NewTextHandler(&b, nil)) - name := ParseNameFill(s, "") + name := ParseName(s, FillNothing) log.Info("", "name", name) want := fmt.Sprintf("name=%s", name.GoString()) got := b.String() @@ -258,7 +299,7 @@ func TestNameGoString(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - p := ParseNameFill(tt.in, "") + p := ParseName(tt.in, FillNothing) tt.wantGoString = cmp.Or(tt.wantGoString, tt.in) if g := fmt.Sprintf("%#v", p); g != tt.wantGoString { t.Errorf("GoString() = %q; want %q", g, tt.wantGoString) @@ -286,11 +327,14 @@ func TestDisplayShortest(t *testing.T) { {"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false}, {"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false}, + // zero value + {"", MaskDefault, "", true}, + // invalid mask {"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true}, // DefaultMask - {"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false}, + {"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false}, // Auto-Fill {"x", "example.com/library/_:latest", "x", false}, @@ -309,7 +353,7 @@ func TestDisplayShortest(t *testing.T) { } }() - p := ParseNameFill(tt.in, "") + p := ParseName(tt.in, FillNothing) t.Logf("ParseName(%q) = %#v", tt.in, p) if g := p.DisplayShortest(tt.mask); g != tt.want { t.Errorf("got = %q; want %q", g, tt.want) @@ -320,7 +364,7 @@ func TestDisplayShortest(t *testing.T) { func TestParseNameAllocs(t *testing.T) { allocs := testing.AllocsPerRun(1000, func() { - keep(ParseNameFill("example.com/mistral:7b+Q4_0", "")) + keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing)) }) if allocs > 0 { t.Errorf("ParseName allocs = %v; want 0", allocs) @@ -331,7 +375,7 @@ func BenchmarkParseName(b *testing.B) { b.ReportAllocs() for range b.N { - keep(ParseNameFill("example.com/mistral:7b+Q4_0", "")) + keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing)) } } @@ -346,7 +390,7 @@ func FuzzParseName(f *testing.F) { f.Add(":@!@") f.Add("...") f.Fuzz(func(t *testing.T, s string) { - r0 := ParseNameFill(s, "") + r0 := ParseName(s, FillNothing) if strings.Contains(s, "..") && !r0.IsZero() { t.Fatalf("non-zero value for path with '..': %q", s) @@ -369,36 +413,15 @@ func FuzzParseName(f *testing.F) { t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s) } - r1 := ParseNameFill(r0.String(), "") + r1 := ParseName(r0.String(), FillNothing) if !r0.EqualFold(r1) { t.Errorf("round-trip mismatch: %+v != %+v", r0, r1) } }) } -func TestFill(t *testing.T) { - cases := []struct { - dst string - src string - want string - }{ - {"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"}, - {"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"}, - {"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"}, - } - - for _, tt := range cases { - t.Run(tt.dst, func(t *testing.T) { - r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, "")) - if r.String() != tt.want { - t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want) - } - }) - } -} - func TestNameStringAllocs(t *testing.T) { - name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "") + name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing) allocs := testing.AllocsPerRun(1000, func() { keep(name.String()) }) @@ -407,25 +430,16 @@ func TestNameStringAllocs(t *testing.T) { } } -func ExampleFill() { - defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "") - r := Fill(ParseNameFill("mistral", ""), defaults) - fmt.Println(r) - - // Output: - // registry.ollama.com/library/mistral:latest+Q4_0 -} - func ExampleName_MapHash() { m := map[uint64]bool{} // key 1 - m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true - m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true - m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true + m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true + m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true + m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true // key 2 - m[ParseNameFill("mistral:LATest", "").MapHash()] = true + m[ParseName("mistral:LATest", FillNothing).MapHash()] = true fmt.Println(len(m)) // Output: @@ -434,9 +448,9 @@ func ExampleName_MapHash() { func ExampleName_CompareFold_sort() { names := []Name{ - ParseNameFill("mistral:latest", ""), - ParseNameFill("mistRal:7b+q4", ""), - ParseNameFill("MIstral:7b", ""), + ParseName("mistral:latest", FillNothing), + ParseName("mistRal:7b+q4", FillNothing), + ParseName("MIstral:7b", FillNothing), } slices.SortFunc(names, Name.CompareFold) @@ -457,7 +471,7 @@ func ExampleName_completeAndResolved() { "x/y/z:latest+q4_0", "@sha123-1", } { - name := ParseNameFill(s, "") + name := ParseName(s, FillNothing) fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest()) } @@ -468,7 +482,7 @@ func ExampleName_completeAndResolved() { } func ExampleName_DisplayShortest() { - name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "") + name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing) fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest")) fmt.Println(name.DisplayShortest("example.com/_/_:latest")) @@ -476,7 +490,7 @@ func ExampleName_DisplayShortest() { fmt.Println(name.DisplayShortest("_/_/_:_")) // Default - name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "") + name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing) fmt.Println(name.DisplayShortest("")) // Output: