routes: use Manifests for ListHandler
This commit is contained in:
parent
a2fc933fed
commit
c2714fcbfd
3 changed files with 127 additions and 58 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
@ -16,6 +17,7 @@ type Manifest struct {
|
||||||
ManifestV2
|
ManifestV2
|
||||||
|
|
||||||
filepath string
|
filepath string
|
||||||
|
fi os.FileInfo
|
||||||
digest string
|
digest string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,6 +67,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
|
fi, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
sha256sum := sha256.New()
|
sha256sum := sha256.New()
|
||||||
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
|
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -73,6 +80,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||||
return &Manifest{
|
return &Manifest{
|
||||||
ManifestV2: m,
|
ManifestV2: m,
|
||||||
filepath: p,
|
filepath: p,
|
||||||
|
fi: fi,
|
||||||
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
|
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -126,7 +134,8 @@ func Manifests() (map[model.Name]*Manifest, error) {
|
||||||
if n.IsValid() {
|
if n.IsValid() {
|
||||||
m, err := ParseNamedManifest(n)
|
m, err := ParseNamedManifest(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
slog.Warn("bad manifest", "name", n, "error", err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ms[n] = m
|
ms[n] = m
|
||||||
|
|
90
server/manifest_test.go
Normal file
90
server/manifest_test.go
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createManifest(t *testing.T, path, name string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
p := filepath.Join(path, "manifests", name)
|
||||||
|
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(p)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifests(t *testing.T) {
|
||||||
|
cases := map[string][]string{
|
||||||
|
"empty": {},
|
||||||
|
"single": {
|
||||||
|
filepath.Join("host", "namespace", "model", "tag"),
|
||||||
|
},
|
||||||
|
"multiple": {
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
filepath.Join("host", "namespace", "model", "tag"),
|
||||||
|
filepath.Join("host", "namespace", "model", ".hidden"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, wants := range cases {
|
||||||
|
t.Run(n, func(t *testing.T) {
|
||||||
|
d := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", d)
|
||||||
|
|
||||||
|
for _, want := range wants {
|
||||||
|
createManifest(t, d, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, err := Manifests()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ns []model.Name
|
||||||
|
for k := range ms {
|
||||||
|
ns = append(ns, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, want := range wants {
|
||||||
|
n := model.ParseNameFromFilepath(want)
|
||||||
|
if !n.IsValid() && slices.Contains(ns, n) {
|
||||||
|
t.Errorf("unexpected invalid name: %s", want)
|
||||||
|
} else if n.IsValid() && !slices.Contains(ns, n) {
|
||||||
|
t.Errorf("missing valid name: %s", want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -702,49 +702,25 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListModelsHandler(c *gin.Context) {
|
func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||||
manifests, err := GetManifestPath()
|
ms, err := Manifests()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
models := []api.ModelResponse{}
|
models := []api.ModelResponse{}
|
||||||
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
|
for n, m := range ms {
|
||||||
if !info.IsDir() {
|
|
||||||
rel, err := filepath.Rel(manifests, path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
|
|
||||||
return err
|
|
||||||
} else if hidden {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
n := model.ParseNameFromFilepath(rel)
|
|
||||||
if !n.IsValid() {
|
|
||||||
slog.Warn("bad manifest filepath", "path", rel)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := ParseNamedManifest(n)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("bad manifest", "name", n, "error", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := m.Config.Open()
|
f, err := m.Config.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("bad manifest config filepath", "name", n, "error", err)
|
slog.Warn("bad manifest filepath", "name", n, "error", err)
|
||||||
return nil
|
continue
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var c ConfigV2
|
var cf ConfigV2
|
||||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
if err := json.NewDecoder(f).Decode(&cf); err != nil {
|
||||||
slog.Warn("bad manifest config", "name", n, "error", err)
|
slog.Warn("bad manifest config", "name", n, "error", err)
|
||||||
return nil
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// tag should never be masked
|
// tag should never be masked
|
||||||
|
@ -753,23 +729,17 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||||
Name: n.DisplayShortest(),
|
Name: n.DisplayShortest(),
|
||||||
Size: m.Size(),
|
Size: m.Size(),
|
||||||
Digest: m.digest,
|
Digest: m.digest,
|
||||||
ModifiedAt: info.ModTime(),
|
ModifiedAt: m.fi.ModTime(),
|
||||||
Details: api.ModelDetails{
|
Details: api.ModelDetails{
|
||||||
Format: c.ModelFormat,
|
Format: cf.ModelFormat,
|
||||||
Family: c.ModelFamily,
|
Family: cf.ModelFamily,
|
||||||
Families: c.ModelFamilies,
|
Families: cf.ModelFamilies,
|
||||||
ParameterSize: c.ModelType,
|
ParameterSize: cf.ModelType,
|
||||||
QuantizationLevel: c.FileType,
|
QuantizationLevel: cf.FileType,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
|
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
|
||||||
// most recently modified first
|
// most recently modified first
|
||||||
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
|
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
|
||||||
|
|
Loading…
Reference in a new issue