routes: use Manifests for ListHandler
This commit is contained in:
parent
a2fc933fed
commit
c2714fcbfd
3 changed files with 127 additions and 58 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
|
@ -16,6 +17,7 @@ type Manifest struct {
|
|||
ManifestV2
|
||||
|
||||
filepath string
|
||||
fi os.FileInfo
|
||||
digest string
|
||||
}
|
||||
|
||||
|
@ -65,6 +67,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sha256sum := sha256.New()
|
||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
|
||||
return nil, err
|
||||
|
@ -73,6 +80,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
|||
return &Manifest{
|
||||
ManifestV2: m,
|
||||
filepath: p,
|
||||
fi: fi,
|
||||
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
|
||||
}, nil
|
||||
}
|
||||
|
@ -126,7 +134,8 @@ func Manifests() (map[model.Name]*Manifest, error) {
|
|||
if n.IsValid() {
|
||||
m, err := ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
slog.Warn("bad manifest", "name", n, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
ms[n] = m
|
||||
|
|
90
server/manifest_test.go
Normal file
90
server/manifest_test.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func createManifest(t *testing.T, path, name string) {
|
||||
t.Helper()
|
||||
|
||||
p := filepath.Join(path, "manifests", name)
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f, err := os.Create(p)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManifests(t *testing.T) {
|
||||
cases := map[string][]string{
|
||||
"empty": {},
|
||||
"single": {
|
||||
filepath.Join("host", "namespace", "model", "tag"),
|
||||
},
|
||||
"multiple": {
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
|
||||
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
|
||||
},
|
||||
"hidden": {
|
||||
filepath.Join("host", "namespace", "model", "tag"),
|
||||
filepath.Join("host", "namespace", "model", ".hidden"),
|
||||
},
|
||||
}
|
||||
|
||||
for n, wants := range cases {
|
||||
t.Run(n, func(t *testing.T) {
|
||||
d := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", d)
|
||||
|
||||
for _, want := range wants {
|
||||
createManifest(t, d, want)
|
||||
}
|
||||
|
||||
ms, err := Manifests()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var ns []model.Name
|
||||
for k := range ms {
|
||||
ns = append(ns, k)
|
||||
}
|
||||
|
||||
for _, want := range wants {
|
||||
n := model.ParseNameFromFilepath(want)
|
||||
if !n.IsValid() && slices.Contains(ns, n) {
|
||||
t.Errorf("unexpected invalid name: %s", want)
|
||||
} else if n.IsValid() && !slices.Contains(ns, n) {
|
||||
t.Errorf("missing valid name: %s", want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -702,49 +702,25 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||
}
|
||||
|
||||
func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||
manifests, err := GetManifestPath()
|
||||
ms, err := Manifests()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
models := []api.ModelResponse{}
|
||||
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
|
||||
if !info.IsDir() {
|
||||
rel, err := filepath.Rel(manifests, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
|
||||
return err
|
||||
} else if hidden {
|
||||
return nil
|
||||
}
|
||||
|
||||
n := model.ParseNameFromFilepath(rel)
|
||||
if !n.IsValid() {
|
||||
slog.Warn("bad manifest filepath", "path", rel)
|
||||
return nil
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
slog.Warn("bad manifest", "name", n, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
for n, m := range ms {
|
||||
f, err := m.Config.Open()
|
||||
if err != nil {
|
||||
slog.Warn("bad manifest config filepath", "name", n, "error", err)
|
||||
return nil
|
||||
slog.Warn("bad manifest filepath", "name", n, "error", err)
|
||||
continue
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var c ConfigV2
|
||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
||||
var cf ConfigV2
|
||||
if err := json.NewDecoder(f).Decode(&cf); err != nil {
|
||||
slog.Warn("bad manifest config", "name", n, "error", err)
|
||||
return nil
|
||||
continue
|
||||
}
|
||||
|
||||
// tag should never be masked
|
||||
|
@ -753,23 +729,17 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
|
|||
Name: n.DisplayShortest(),
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: info.ModTime(),
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Details: api.ModelDetails{
|
||||
Format: c.ModelFormat,
|
||||
Family: c.ModelFamily,
|
||||
Families: c.ModelFamilies,
|
||||
ParameterSize: c.ModelType,
|
||||
QuantizationLevel: c.FileType,
|
||||
Format: cf.ModelFormat,
|
||||
Family: cf.ModelFamily,
|
||||
Families: cf.ModelFamilies,
|
||||
ParameterSize: cf.ModelType,
|
||||
QuantizationLevel: cf.FileType,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
|
||||
// most recently modified first
|
||||
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
|
||||
|
|
Loading…
Reference in a new issue