routes: use Manifests for ListHandler

This commit is contained in:
Michael Yang 2024-05-06 16:34:13 -07:00
parent a2fc933fed
commit c2714fcbfd
3 changed files with 127 additions and 58 deletions

View file

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
@ -16,6 +17,7 @@ type Manifest struct {
ManifestV2 ManifestV2
filepath string filepath string
fi os.FileInfo
digest string digest string
} }
@ -65,6 +67,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
} }
defer f.Close() defer f.Close()
fi, err := f.Stat()
if err != nil {
return nil, err
}
sha256sum := sha256.New() sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil { if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err return nil, err
@ -73,6 +80,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return &Manifest{ return &Manifest{
ManifestV2: m, ManifestV2: m,
filepath: p, filepath: p,
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil }, nil
} }
@ -126,7 +134,8 @@ func Manifests() (map[model.Name]*Manifest, error) {
if n.IsValid() { if n.IsValid() {
m, err := ParseNamedManifest(n) m, err := ParseNamedManifest(n)
if err != nil { if err != nil {
return nil, err slog.Warn("bad manifest", "name", n, "error", err)
continue
} }
ms[n] = m ms[n] = m

90
server/manifest_test.go Normal file
View file

@ -0,0 +1,90 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
func createManifest(t *testing.T, path, name string) {
t.Helper()
p := filepath.Join(path, "manifests", name)
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
t.Fatal(err)
}
f, err := os.Create(p)
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
t.Fatal(err)
}
}
func TestManifests(t *testing.T) {
cases := map[string][]string{
"empty": {},
"single": {
filepath.Join("host", "namespace", "model", "tag"),
},
"multiple": {
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
},
"hidden": {
filepath.Join("host", "namespace", "model", "tag"),
filepath.Join("host", "namespace", "model", ".hidden"),
},
}
for n, wants := range cases {
t.Run(n, func(t *testing.T) {
d := t.TempDir()
t.Setenv("OLLAMA_MODELS", d)
for _, want := range wants {
createManifest(t, d, want)
}
ms, err := Manifests()
if err != nil {
t.Fatal(err)
}
var ns []model.Name
for k := range ms {
ns = append(ns, k)
}
for _, want := range wants {
n := model.ParseNameFromFilepath(want)
if !n.IsValid() && slices.Contains(ns, n) {
t.Errorf("unexpected invalid name: %s", want)
} else if n.IsValid() && !slices.Contains(ns, n) {
t.Errorf("missing valid name: %s", want)
}
}
})
}
}

View file

@ -702,49 +702,25 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
func (s *Server) ListModelsHandler(c *gin.Context) { func (s *Server) ListModelsHandler(c *gin.Context) {
manifests, err := GetManifestPath() ms, err := Manifests()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
models := []api.ModelResponse{} models := []api.ModelResponse{}
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error { for n, m := range ms {
if !info.IsDir() {
rel, err := filepath.Rel(manifests, path)
if err != nil {
return err
}
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
return err
} else if hidden {
return nil
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest filepath", "path", rel)
return nil
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
return nil
}
f, err := m.Config.Open() f, err := m.Config.Open()
if err != nil { if err != nil {
slog.Warn("bad manifest config filepath", "name", n, "error", err) slog.Warn("bad manifest filepath", "name", n, "error", err)
return nil continue
} }
defer f.Close() defer f.Close()
var c ConfigV2 var cf ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil { if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err) slog.Warn("bad manifest config", "name", n, "error", err)
return nil continue
} }
// tag should never be masked // tag should never be masked
@ -753,23 +729,17 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
Name: n.DisplayShortest(), Name: n.DisplayShortest(),
Size: m.Size(), Size: m.Size(),
Digest: m.digest, Digest: m.digest,
ModifiedAt: info.ModTime(), ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{ Details: api.ModelDetails{
Format: c.ModelFormat, Format: cf.ModelFormat,
Family: c.ModelFamily, Family: cf.ModelFamily,
Families: c.ModelFamilies, Families: cf.ModelFamilies,
ParameterSize: c.ModelType, ParameterSize: cf.ModelType,
QuantizationLevel: c.FileType, QuantizationLevel: cf.FileType,
}, },
}) })
} }
return nil
}); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slices.SortStableFunc(models, func(i, j api.ModelResponse) int { slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
// most recently modified first // most recently modified first
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix()) return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())