From 548a7df014df6e50947878bfc2eb6fc3106e3906 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 17 Apr 2024 14:54:14 -0700 Subject: [PATCH] update list handler to use model.Name --- server/images.go | 11 ------ server/manifest.go | 79 ++++++++++++++++++++++++++++++++++++++ server/manifests.go | 34 ---------------- server/routes.go | 83 +++++++++++++++++++++------------------- types/model/name.go | 54 ++++++++++++++++++++++++-- types/model/name_test.go | 43 +++++++++++++++++++++ 6 files changed, 216 insertions(+), 88 deletions(-) create mode 100644 server/manifest.go delete mode 100644 server/manifests.go diff --git a/server/images.go b/server/images.go index 2817b1d3..a5a5dacc 100644 --- a/server/images.go +++ b/server/images.go @@ -52,7 +52,6 @@ type Model struct { System string License []string Digest string - Size int64 Options map[string]interface{} Messages []Message } @@ -161,15 +160,6 @@ type RootFS struct { DiffIDs []string `json:"diff_ids"` } -func (m *ManifestV2) GetTotalSize() (total int64) { - for _, layer := range m.Layers { - total += layer.Size - } - - total += m.Config.Size - return total -} - func GetManifest(mp ModelPath) (*ManifestV2, string, error) { fp, err := mp.GetManifestPath() if err != nil { @@ -210,7 +200,6 @@ func GetModel(name string) (*Model, error) { Digest: digest, Template: "{{ .Prompt }}", License: []string{}, - Size: manifest.GetTotalSize(), } filename, err := GetBlobsPath(manifest.Config.Digest) diff --git a/server/manifest.go b/server/manifest.go new file mode 100644 index 00000000..8a17700e --- /dev/null +++ b/server/manifest.go @@ -0,0 +1,79 @@ +package server + +import ( + "bytes" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/ollama/ollama/types/model" +) + +type Manifest struct { + ManifestV2 + Digest string `json:"-"` +} + +func (m *Manifest) Size() (size int64) { + for _, layer := range append(m.Layers, m.Config) { + size += layer.Size + } + + return +} + +func ParseNamedManifest(name model.Name) (*Manifest, error) { + if !name.IsFullyQualified() { + return nil, model.Unqualified(name) + } + + manifests, err := GetManifestPath() + if err != nil { + return nil, err + } + + var manifest ManifestV2 + manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath())) + if err != nil { + return nil, err + } + + sha256sum := sha256.New() + if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil { + return nil, err + } + + return &Manifest{ + ManifestV2: manifest, + Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), + }, nil +} + +func WriteManifest(name string, config *Layer, layers []*Layer) error { + manifest := ManifestV2{ + SchemaVersion: 2, + MediaType: "application/vnd.docker.distribution.manifest.v2+json", + Config: config, + Layers: layers, + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(manifest); err != nil { + return err + } + + modelpath := ParseModelPath(name) + manifestPath, err := modelpath.GetManifestPath() + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { + return err + } + + return os.WriteFile(manifestPath, b.Bytes(), 0o644) +} diff --git a/server/manifests.go b/server/manifests.go deleted file mode 100644 index 2b39db65..00000000 --- a/server/manifests.go +++ /dev/null @@ -1,34 +0,0 @@ -package server - -import ( - "bytes" - "encoding/json" - "os" - "path/filepath" -) - -func WriteManifest(name string, config *Layer, layers []*Layer) error { - manifest := ManifestV2{ - SchemaVersion: 2, - MediaType: "application/vnd.docker.distribution.manifest.v2+json", - Config: config, - Layers: layers, - } - - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(manifest); err != nil { - return err - } - - modelpath := ParseModelPath(name) - manifestPath, err := modelpath.GetManifestPath() - if err != nil { - return err - } - - if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { - return err - } - - return os.WriteFile(manifestPath, b.Bytes(), 0o644) -} diff --git a/server/routes.go b/server/routes.go index 0a11909c..33f00382 100644 --- a/server/routes.go +++ b/server/routes.go @@ -719,62 +719,65 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } func (s *Server) ListModelsHandler(c *gin.Context) { - models := make([]api.ModelResponse, 0) - manifestsPath, err := GetManifestPath() + manifests, err := GetManifestPath() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - modelResponse := func(modelName string) (api.ModelResponse, error) { - model, err := GetModel(modelName) - if err != nil { - return api.ModelResponse{}, err - } - - modelDetails := api.ModelDetails{ - Format: model.Config.ModelFormat, - Family: model.Config.ModelFamily, - Families: model.Config.ModelFamilies, - ParameterSize: model.Config.ModelType, - QuantizationLevel: model.Config.FileType, - } - - return api.ModelResponse{ - Model: model.ShortName, - Name: model.ShortName, - Size: model.Size, - Digest: model.Digest, - Details: modelDetails, - }, nil - } - - walkFunc := func(path string, info os.FileInfo, _ error) error { + var models []api.ModelResponse + if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error { if !info.IsDir() { - path, tag := filepath.Split(path) - model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator)) - modelPath := strings.Join([]string{model, tag}, ":") - canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/") - - resp, err := modelResponse(canonicalModelPath) + rel, err := filepath.Rel(manifests, path) if err != nil { - slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath)) - // nolint: nilerr - return nil + return err } - resp.ModifiedAt = info.ModTime() - models = append(models, resp) + n := model.ParseNameFromFilepath(rel) + m, err := ParseNamedManifest(n) + if err != nil { + return err + } + + f, err := m.Config.Open() + if err != nil { + return err + } + defer f.Close() + + var c ConfigV2 + if err := json.NewDecoder(f).Decode(&c); err != nil { + return err + } + + // tag should never be masked + models = append(models, api.ModelResponse{ + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + Size: m.Size(), + Digest: m.Digest, + ModifiedAt: info.ModTime(), + Details: api.ModelDetails{ + Format: c.ModelFormat, + Family: c.ModelFamily, + Families: c.ModelFamilies, + ParameterSize: c.ModelType, + QuantizationLevel: c.FileType, + }, + }) } return nil - } - - if err := filepath.Walk(manifestsPath, walkFunc); err != nil { + }); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + slices.SortStableFunc(models, func(i, j api.ModelResponse) int { + // most recently modified first + return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix()) + }) + c.JSON(http.StatusOK, api.ListResponse{Models: models}) } diff --git a/types/model/name.go b/types/model/name.go index fbb30fd4..6d2a187b 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -35,6 +35,12 @@ func Unqualified(n Name) error { // spot in logs. const MissingPart = "!MISSING!" +const ( + defaultHost = "registry.ollama.ai" + defaultNamespace = "library" + defaultTag = "latest" +) + // DefaultName returns a name with the default values for the host, namespace, // and tag parts. The model and digest parts are empty. // @@ -43,9 +49,9 @@ const MissingPart = "!MISSING!" // - The default tag is ("latest") func DefaultName() Name { return Name{ - Host: "registry.ollama.ai", - Namespace: "library", - Tag: "latest", + Host: defaultHost, + Namespace: defaultNamespace, + Tag: defaultTag, } } @@ -169,6 +175,27 @@ func ParseNameBare(s string) Name { return n } +// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are +// expected to be in the form: +// +// { host } "/" { namespace } "/" { model } "/" { tag } +func ParseNameFromFilepath(s string) (n Name) { + parts := strings.Split(s, string(filepath.Separator)) + if len(parts) != 4 { + return Name{} + } + + n.Host = parts[0] + n.Namespace = parts[1] + n.Model = parts[2] + n.Tag = parts[3] + if !n.IsFullyQualified() { + return Name{} + } + + return n +} + // Merge merges the host, namespace, and tag parts of the two names, // preferring the non-empty parts of a. func Merge(a, b Name) Name { @@ -203,6 +230,27 @@ func (n Name) String() string { return b.String() } +// DisplayShort returns a short string version of the name. +func (n Name) DisplayShortest() string { + var sb strings.Builder + + if n.Host != defaultHost { + sb.WriteString(n.Host) + sb.WriteByte('/') + sb.WriteString(n.Namespace) + sb.WriteByte('/') + } else if n.Namespace != defaultNamespace { + sb.WriteString(n.Namespace) + sb.WriteByte('/') + } + + // always include model and tag + sb.WriteString(n.Model) + sb.WriteString(":") + sb.WriteString(n.Tag) + return sb.String() +} + // IsValid reports whether all parts of the name are present and valid. The // digest is a special case, and is checked for validity only if present. func (n Name) IsValid() bool { diff --git a/types/model/name_test.go b/types/model/name_test.go index 47263c20..19bc2e2d 100644 --- a/types/model/name_test.go +++ b/types/model/name_test.go @@ -309,6 +309,49 @@ func TestParseDigest(t *testing.T) { } } +func TestParseNameFromFilepath(t *testing.T) { + cases := map[string]Name{ + filepath.Join("host", "namespace", "model", "tag"): {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"}, + filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"}, + filepath.Join("namespace", "model", "tag"): {}, + filepath.Join("model", "tag"): {}, + filepath.Join("model"): {}, + filepath.Join("..", "..", "model", "tag"): {}, + filepath.Join("", "namespace", ".", "tag"): {}, + filepath.Join(".", ".", ".", "."): {}, + filepath.Join("/", "path", "to", "random", "file"): {}, + } + + for in, want := range cases { + t.Run(in, func(t *testing.T) { + got := ParseNameFromFilepath(in) + + if !reflect.DeepEqual(got, want) { + t.Errorf("parseNameFromFilepath(%q) = %v; want %v", in, got, want) + } + }) + } +} + +func TestDisplayShortest(t *testing.T) { + cases := map[string]string{ + "registry.ollama.ai/library/model:latest": "model:latest", + "registry.ollama.ai/library/model:tag": "model:tag", + "registry.ollama.ai/namespace/model:tag": "namespace/model:tag", + "host/namespace/model:tag": "host/namespace/model:tag", + "host/library/model:tag": "host/library/model:tag", + } + + for in, want := range cases { + t.Run(in, func(t *testing.T) { + got := ParseNameBare(in).DisplayShortest() + if got != want { + t.Errorf("parseName(%q).DisplayShortest() = %q; want %q", in, got, want) + } + }) + } +} + func FuzzName(f *testing.F) { for s := range testCases { f.Add(s)