types/model: make ParseName variants less confusing (#3617)

Also, fix http stripping bug.

Also, improve upon docs about fills and masks.
This commit is contained in:
Blake Mizerany 2024-04-12 13:57:57 -07:00 committed by GitHub
parent 2b341069a7
commit 08655170aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 176 additions and 135 deletions

View file

@ -3,6 +3,7 @@ package model
import (
"cmp"
"errors"
"fmt"
"hash/maphash"
"io"
"log/slog"
@ -25,11 +26,17 @@ var (
// Defaults
const (
// DefaultMask is the default mask used by [Name.DisplayShortest].
DefaultMask = "registry.ollama.ai/library/_:latest"
// MaskDefault is the default mask used by [Name.DisplayShortest].
MaskDefault = "registry.ollama.ai/library/?:latest"
// MaskNothing is a mask that masks nothing.
MaskNothing = "?/?/?:?"
// FillDefault is the default fill used by [ParseName].
DefaultFill = "registry.ollama.ai/library/_:latest"
FillDefault = "registry.ollama.ai/library/?:latest+Q4_0"
// FillNothing is a fill that fills nothing.
FillNothing = "?/?/?:?+?"
)
const MaxNamePartLen = 128
@ -47,11 +54,7 @@ const (
PartBuild
PartDigest
// Invalid is a special part that is used to indicate that a part is
// invalid. It is not a valid part of a Name.
//
// It should be kept as the last part in the list.
PartInvalid
PartExtraneous = -1
)
var kindNames = map[PartKind]string{
@ -61,7 +64,6 @@ var kindNames = map[PartKind]string{
PartTag: "Tag",
PartBuild: "Build",
PartDigest: "Digest",
PartInvalid: "Invalid",
}
func (k PartKind) String() string {
@ -96,8 +98,6 @@ func (k PartKind) String() string {
// The parts can be obtained in their original form by calling [Name.Parts].
//
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
//
// To make a Name by filling in missing parts from another Name, use [Fill].
type Name struct {
_ structs.Incomparable
parts [6]string // host, namespace, model, tag, build, digest
@ -109,7 +109,7 @@ type Name struct {
// and mean zero allocations for String.
}
// ParseNameFill parses s into a Name, and returns the result of filling it with
// ParseName parses s into a Name, and returns the result of filling it with
// defaults. The input string must be a valid string
// representation of a model name in the form:
//
@ -139,19 +139,19 @@ type Name struct {
//
// It returns the zero value if any part is invalid.
//
// As a rule of thumb, an valid name is one that can be round-tripped with
// the [Name.String] method. That means ("x+") is invalid because
// [Name.String] will not print a "+" if the build is empty.
// # Fills
//
// For more about filling in missing parts, see [Fill].
func ParseNameFill(s, defaults string) Name {
// For any valid s, the fill string is used to fill in missing parts of the
// Name. The fill string must be a valid Name with the exception that any part
// may be the string ("?"), which will not be considered for filling. If fill
// is the empty string, [FillDefault] is used.
func ParseName(s, fill string) Name {
var r Name
parts(s)(func(kind PartKind, part string) bool {
if kind == PartInvalid {
if kind == PartDigest && !ParseDigest(part).IsValid() {
r = Name{}
return false
}
if kind == PartDigest && !ParseDigest(part).IsValid() {
if kind == PartExtraneous || !isValidPart(kind, part) {
r = Name{}
return false
}
@ -159,34 +159,48 @@ func ParseNameFill(s, defaults string) Name {
return true
})
if r.IsValid() || r.IsResolved() {
if defaults == "" {
return r
}
return Fill(r, ParseNameFill(defaults, ""))
fill = cmp.Or(fill, FillDefault)
return fillName(r, fill)
}
return Name{}
}
// ParseName is equal to ParseNameFill(s, DefaultFill).
func ParseName(s string) Name {
return ParseNameFill(s, DefaultFill)
// parseMask parses s as a mask or fill string and returns the result as a
// Name. A part spelled "?" is a placeholder: it is accepted but left empty
// in the returned Name.
//
// It panics if any other part fails isValidPart for its kind.
func parseMask(s string) Name {
var r Name
parts(s)(func(kind PartKind, part string) bool {
if part == "?" {
// mask part; treat as empty but valid
return true
}
if !isValidPart(kind, part) {
panic(fmt.Errorf("invalid mask part %s: %q", kind, part))
}
r.parts[kind] = part
return true
})
return r
}
func MustParseNameFill(s, defaults string) Name {
r := ParseNameFill(s, "")
// MustParseName is like [ParseName] but panics if the resulting Name is not
// valid.
//
// NOTE(review): defaults is not forwarded. ParseName is called with an empty
// fill, which it replaces with [FillDefault], so the caller's defaults have
// no effect. Presumably defaults was meant to be passed through as the fill;
// confirm and fix.
func MustParseName(s, defaults string) Name {
r := ParseName(s, "")
if !r.IsValid() {
panic("model.MustParseName: invalid name: " + s)
panic("invalid Name: " + s)
}
return r
}
// Fill fills in the missing parts of dst with the parts of src.
// fillName fills in the missing parts of r with the parts of fill.
//
// The returned Name will only be valid if r is valid.
func Fill(dst, src Name) Name {
var r Name
//
// It skips fill parts that are "?".
func fillName(r Name, fill string) Name {
// fill is parsed with parseMask, which panics on an invalid fill part.
f := parseMask(fill)
for i := range r.parts {
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
if f.parts[i] == "?" {
// NOTE(review): parseMask leaves "?" parts empty instead of storing
// "?", so this branch looks unreachable. The cmp.Or below already
// keeps r.parts[i] when the fill part is empty.
continue
}
// Keep r's own part when it is set; otherwise take the fill's part.
r.parts[i] = cmp.Or(r.parts[i], f.parts[i])
}
return r
}
@ -231,30 +245,58 @@ func (r Name) slice(from, to PartKind) Name {
return v
}
// DisplayShortest returns the shortest possible display string in form:
// DisplayShortest returns the shortest possible, masked display string in form:
//
// [host/][<namespace>/]<model>[:<tag>]
//
// The host is omitted if it is the mask host is the same as r.
// The namespace is omitted if the host and the namespace are the same as r.
// The tag is omitted if it is the mask tag is the same as r.
// # Masks
//
// The mask is a string that specifies which parts of the name to omit based
// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name
// that are the same as the mask, moving from left to right until the first
// unequal part is found. It then moves right to left until the first unequal
// part is found. The result is the shortest possible display string.
//
// Unlike a [Name] the mask can contain "?" characters which are treated as
// wildcards. A "?" will never match a part of the name, since a valid name
// can never contain a "?" character.
//
// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked
// with ("registry.ollama.ai/library/?:latest") will produce the display string
// ("mistral").
//
// If mask is the empty string, then [MaskDefault] is used.
//
// # Safety
//
// To avoid unsafe behavior, DisplayShortest will panic if r is the zero
// value to prevent the return of a "" string. Callers should consult
// [Name.IsValid] before calling this method.
//
// # Builds
//
// For now, DisplayShortest does not consider the build or return one in the
// result. We can lift this restriction when needed.
func (r Name) DisplayShortest(mask string) string {
mask = cmp.Or(mask, DefaultMask)
d := ParseName(mask)
if !d.IsValid() {
panic("mask is an invalid Name")
mask = cmp.Or(mask, MaskDefault)
d := parseMask(mask)
if d.IsZero() {
panic(fmt.Errorf("invalid mask %q", mask))
}
equalSlice := func(form, to PartKind) bool {
return r.slice(form, to).EqualFold(d.slice(form, to))
if r.IsZero() {
panic("invalid Name")
}
if equalSlice(PartHost, PartNamespace) {
r.parts[PartNamespace] = ""
for i := range PartTag {
if !strings.EqualFold(r.parts[i], d.parts[i]) {
break
}
r.parts[i] = ""
}
if equalSlice(PartHost, PartHost) {
r.parts[PartHost] = ""
}
if equalSlice(PartTag, PartTag) {
r.parts[PartTag] = ""
for i := PartTag; i >= 0; i-- {
if !strings.EqualFold(r.parts[i], d.parts[i]) {
break
}
r.parts[i] = ""
}
return r.slice(PartHost, PartTag).String()
}
@ -418,27 +460,16 @@ type iter_Seq2[A, B any] func(func(A, B) bool)
// No other normalizations are performed.
func parts(s string) iter_Seq2[PartKind, string] {
return func(yield func(PartKind, string) bool) {
//nolint:gosimple
if strings.HasPrefix(s, "http://") {
s = s[len("http://"):]
}
//nolint:gosimple
if strings.HasPrefix(s, "https://") {
s = s[len("https://"):]
} else {
s = strings.TrimPrefix(s, "https://")
}
if len(s) > MaxNamePartLen || len(s) == 0 {
return
}
yieldValid := func(kind PartKind, part string) bool {
if !isValidPart(kind, part) {
yield(PartInvalid, "")
return false
}
return yield(kind, part)
}
numConsecutiveDots := 0
partLen := 0
state, j := PartDigest, len(s)
@ -448,7 +479,7 @@ func parts(s string) iter_Seq2[PartKind, string] {
// we don't keep spinning on it, waiting for
// an isInValidPart check which would scan
// over it again.
yield(PartInvalid, "")
yield(state, s[i+1:j])
return
}
@ -456,7 +487,7 @@ func parts(s string) iter_Seq2[PartKind, string] {
case '@':
switch state {
case PartDigest:
if !yieldValid(PartDigest, s[i+1:j]) {
if !yield(PartDigest, s[i+1:j]) {
return
}
if i == 0 {
@ -468,67 +499,63 @@ func parts(s string) iter_Seq2[PartKind, string] {
}
state, j, partLen = PartBuild, i, 0
default:
yield(PartInvalid, "")
yield(PartExtraneous, s[i+1:j])
return
}
case '+':
switch state {
case PartBuild, PartDigest:
if !yieldValid(PartBuild, s[i+1:j]) {
if !yield(PartBuild, s[i+1:j]) {
return
}
state, j, partLen = PartTag, i, 0
default:
yield(PartInvalid, "")
yield(PartExtraneous, s[i+1:j])
return
}
case ':':
switch state {
case PartTag, PartBuild, PartDigest:
if !yieldValid(PartTag, s[i+1:j]) {
if !yield(PartTag, s[i+1:j]) {
return
}
state, j, partLen = PartModel, i, 0
default:
yield(PartInvalid, "")
yield(PartExtraneous, s[i+1:j])
return
}
case '/':
switch state {
case PartModel, PartTag, PartBuild, PartDigest:
if !yieldValid(PartModel, s[i+1:j]) {
if !yield(PartModel, s[i+1:j]) {
return
}
state, j = PartNamespace, i
case PartNamespace:
if !yieldValid(PartNamespace, s[i+1:j]) {
if !yield(PartNamespace, s[i+1:j]) {
return
}
state, j, partLen = PartHost, i, 0
default:
yield(PartInvalid, "")
yield(PartExtraneous, s[i+1:j])
return
}
default:
if s[i] == '.' {
if numConsecutiveDots++; numConsecutiveDots > 1 {
yield(PartInvalid, "")
yield(state, "")
return
}
} else {
numConsecutiveDots = 0
}
if !isValidByteFor(state, s[i]) {
yield(PartInvalid, "")
return
}
}
}
if state <= PartNamespace {
yieldValid(state, s[:j])
yield(state, s[:j])
} else {
yieldValid(PartModel, s[:j])
yield(PartModel, s[:j])
}
}
}

View file

@ -111,11 +111,11 @@ func TestNameConsecutiveDots(t *testing.T) {
for i := 1; i < 10; i++ {
s := strings.Repeat(".", i)
if i > 1 {
if g := ParseNameFill(s, "").String(); g != "" {
if g := ParseName(s, FillNothing).String(); g != "" {
t.Errorf("ParseName(%q) = %q; want empty string", s, g)
}
} else {
if g := ParseNameFill(s, "").String(); g != s {
if g := ParseName(s, FillNothing).String(); g != s {
t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
}
}
@ -148,14 +148,14 @@ func TestParseName(t *testing.T) {
s := prefix + baseName
t.Run(s, func(t *testing.T) {
name := ParseNameFill(s, "")
name := ParseName(s, FillNothing)
got := fieldsFromName(name)
if got != want {
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
}
// test round-trip
if !ParseNameFill(name.String(), "").EqualFold(name) {
if !ParseName(name.String(), FillNothing).EqualFold(name) {
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
}
})
@ -163,6 +163,47 @@ func TestParseName(t *testing.T) {
}
}
// TestParseNameFill checks that ParseName fills the parts missing from the
// input with the corresponding parts of the fill string, keeps parts the
// input already has, and yields a Name whose String is empty for invalid
// input.
func TestParseNameFill(t *testing.T) {
cases := []struct {
in string
fill string
want string
}{
{"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"},
{"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"},
{"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"},
// Invalid
{"", "example.com/library/?:latest+Q4_0", ""},
{"llama2:?", "example.com/library/?:latest+Q4_0", ""},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
name := ParseName(tt.in, tt.fill)
if g := name.String(); g != tt.want {
t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want)
}
})
}
}
// TestParseNameHTTPDoublePrefixStrip checks that an input carrying both an
// "http://" and an "https://" prefix, in either order, does not parse to a
// valid Name.
func TestParseNameHTTPDoublePrefixStrip(t *testing.T) {
cases := []string{
"http://https://valid.com/valid/valid:latest",
"https://http://valid.com/valid/valid:latest",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
name := ParseName(s, FillNothing)
if name.IsValid() {
t.Errorf("expected invalid path; got %#v", name)
}
})
}
}
func TestCompleteWithAndWithoutBuild(t *testing.T) {
cases := []struct {
in string
@ -179,7 +220,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) {
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseNameFill(tt.in, "")
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.IsComplete(); g != tt.complete {
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
@ -194,7 +235,7 @@ func TestCompleteWithAndWithoutBuild(t *testing.T) {
// inlined when used in Complete, preventing any allocations or
// escaping to the heap.
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseNameFill("complete.com/x/mistral:latest+Q4_0", "").IsComplete())
keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete())
})
if allocs > 0 {
t.Errorf("Complete allocs = %v; want 0", allocs)
@ -211,7 +252,7 @@ func TestNameLogValue(t *testing.T) {
t.Run(s, func(t *testing.T) {
var b bytes.Buffer
log := slog.New(slog.NewTextHandler(&b, nil))
name := ParseNameFill(s, "")
name := ParseName(s, FillNothing)
log.Info("", "name", name)
want := fmt.Sprintf("name=%s", name.GoString())
got := b.String()
@ -258,7 +299,7 @@ func TestNameGoString(t *testing.T) {
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
p := ParseNameFill(tt.in, "")
p := ParseName(tt.in, FillNothing)
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
@ -286,11 +327,14 @@ func TestDisplayShortest(t *testing.T) {
{"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
{"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},
// zero value
{"", MaskDefault, "", true},
// invalid mask
{"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},
// DefaultMask
{"registry.ollama.ai/library/mistral:latest+Q4_0", DefaultMask, "mistral", false},
{"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false},
// Auto-Fill
{"x", "example.com/library/_:latest", "x", false},
@ -309,7 +353,7 @@ func TestDisplayShortest(t *testing.T) {
}
}()
p := ParseNameFill(tt.in, "")
p := ParseName(tt.in, FillNothing)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.DisplayShortest(tt.mask); g != tt.want {
t.Errorf("got = %q; want %q", g, tt.want)
@ -320,7 +364,7 @@ func TestDisplayShortest(t *testing.T) {
func TestParseNameAllocs(t *testing.T) {
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
})
if allocs > 0 {
t.Errorf("ParseName allocs = %v; want 0", allocs)
@ -331,7 +375,7 @@ func BenchmarkParseName(b *testing.B) {
b.ReportAllocs()
for range b.N {
keep(ParseNameFill("example.com/mistral:7b+Q4_0", ""))
keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
}
}
@ -346,7 +390,7 @@ func FuzzParseName(f *testing.F) {
f.Add(":@!@")
f.Add("...")
f.Fuzz(func(t *testing.T, s string) {
r0 := ParseNameFill(s, "")
r0 := ParseName(s, FillNothing)
if strings.Contains(s, "..") && !r0.IsZero() {
t.Fatalf("non-zero value for path with '..': %q", s)
@ -369,36 +413,15 @@ func FuzzParseName(f *testing.F) {
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
}
r1 := ParseNameFill(r0.String(), "")
r1 := ParseName(r0.String(), FillNothing)
if !r0.EqualFold(r1) {
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
}
})
}
func TestFill(t *testing.T) {
cases := []struct {
dst string
src string
want string
}{
{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
}
for _, tt := range cases {
t.Run(tt.dst, func(t *testing.T) {
r := Fill(ParseNameFill(tt.dst, ""), ParseNameFill(tt.src, ""))
if r.String() != tt.want {
t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
}
})
}
}
func TestNameStringAllocs(t *testing.T) {
name := ParseNameFill("example.com/ns/mistral:latest+Q4_0", "")
name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing)
allocs := testing.AllocsPerRun(1000, func() {
keep(name.String())
})
@ -407,25 +430,16 @@ func TestNameStringAllocs(t *testing.T) {
}
}
func ExampleFill() {
defaults := ParseNameFill("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0", "")
r := Fill(ParseNameFill("mistral", ""), defaults)
fmt.Println(r)
// Output:
// registry.ollama.com/library/mistral:latest+Q4_0
}
func ExampleName_MapHash() {
m := map[uint64]bool{}
// key 1
m[ParseNameFill("mistral:latest+q4", "").MapHash()] = true
m[ParseNameFill("miSTRal:latest+Q4", "").MapHash()] = true
m[ParseNameFill("mistral:LATest+Q4", "").MapHash()] = true
m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true
m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true
m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true
// key 2
m[ParseNameFill("mistral:LATest", "").MapHash()] = true
m[ParseName("mistral:LATest", FillNothing).MapHash()] = true
fmt.Println(len(m))
// Output:
@ -434,9 +448,9 @@ func ExampleName_MapHash() {
func ExampleName_CompareFold_sort() {
names := []Name{
ParseNameFill("mistral:latest", ""),
ParseNameFill("mistRal:7b+q4", ""),
ParseNameFill("MIstral:7b", ""),
ParseName("mistral:latest", FillNothing),
ParseName("mistRal:7b+q4", FillNothing),
ParseName("MIstral:7b", FillNothing),
}
slices.SortFunc(names, Name.CompareFold)
@ -457,7 +471,7 @@ func ExampleName_completeAndResolved() {
"x/y/z:latest+q4_0",
"@sha123-1",
} {
name := ParseNameFill(s, "")
name := ParseName(s, FillNothing)
fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
}
@ -468,7 +482,7 @@ func ExampleName_completeAndResolved() {
}
func ExampleName_DisplayShortest() {
name := ParseNameFill("example.com/jmorganca/mistral:latest+Q4_0", "")
name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing)
fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
@ -476,7 +490,7 @@ func ExampleName_DisplayShortest() {
fmt.Println(name.DisplayShortest("_/_/_:_"))
// Default
name = ParseNameFill("registry.ollama.ai/library/mistral:latest+Q4_0", "")
name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing)
fmt.Println(name.DisplayShortest(""))
// Output: