Merge pull request #3718 from ollama/mxyng/modelname-3
update delete handler to use model.Name
This commit is contained in:
commit
bca7b12284
8 changed files with 587 additions and 122 deletions
|
@ -771,37 +771,6 @@ func PruneDirectory(path string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteModel(name string) error {
|
|
||||||
mp := ParseModelPath(name)
|
|
||||||
manifest, _, err := GetManifest(mp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
deleteMap := make(map[string]struct{})
|
|
||||||
for _, layer := range manifest.Layers {
|
|
||||||
deleteMap[layer.Digest] = struct{}{}
|
|
||||||
}
|
|
||||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
|
||||||
|
|
||||||
err = deleteUnusedLayers(&mp, deleteMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fp, err := mp.GetManifestPath()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = os.Remove(fp)
|
|
||||||
if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
mp := ParseModelPath(name)
|
||||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||||
|
|
|
@ -88,3 +88,26 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||||
|
|
||||||
return os.Open(blob)
|
return os.Open(blob)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Layer) Remove() error {
|
||||||
|
ms, err := Manifests()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range ms {
|
||||||
|
for _, layer := range append(m.Layers, m.Config) {
|
||||||
|
if layer.Digest == l.Digest {
|
||||||
|
// something is using this layer
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
blob, err := GetBlobsPath(l.Digest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Remove(blob)
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
@ -14,7 +15,10 @@ import (
|
||||||
|
|
||||||
type Manifest struct {
|
type Manifest struct {
|
||||||
ManifestV2
|
ManifestV2
|
||||||
Digest string `json:"-"`
|
|
||||||
|
filepath string
|
||||||
|
fi os.FileInfo
|
||||||
|
digest string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manifest) Size() (size int64) {
|
func (m *Manifest) Size() (size int64) {
|
||||||
|
@ -25,9 +29,28 @@ func (m *Manifest) Size() (size int64) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseNamedManifest(name model.Name) (*Manifest, error) {
|
func (m *Manifest) Remove() error {
|
||||||
if !name.IsFullyQualified() {
|
if err := os.Remove(m.filepath); err != nil {
|
||||||
return nil, model.Unqualified(name)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, layer := range append(m.Layers, m.Config) {
|
||||||
|
if err := layer.Remove(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
manifests, err := GetManifestPath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return PruneDirectory(manifests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||||
|
if !n.IsFullyQualified() {
|
||||||
|
return nil, model.Unqualified(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
manifests, err := GetManifestPath()
|
manifests, err := GetManifestPath()
|
||||||
|
@ -35,20 +58,30 @@ func ParseNamedManifest(name model.Name) (*Manifest, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var manifest ManifestV2
|
p := filepath.Join(manifests, n.Filepath())
|
||||||
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
|
|
||||||
|
var m ManifestV2
|
||||||
|
f, err := os.Open(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
fi, err := f.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sha256sum := sha256.New()
|
sha256sum := sha256.New()
|
||||||
if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
|
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Manifest{
|
return &Manifest{
|
||||||
ManifestV2: manifest,
|
ManifestV2: m,
|
||||||
Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
|
filepath: p,
|
||||||
|
fi: fi,
|
||||||
|
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,3 +110,48 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error {
|
||||||
|
|
||||||
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
|
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Manifests() (map[model.Name]*Manifest, error) {
|
||||||
|
manifests, err := GetManifestPath()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(mxyng): use something less brittle
|
||||||
|
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ms := make(map[model.Name]*Manifest)
|
||||||
|
for _, match := range matches {
|
||||||
|
fi, err := os.Stat(match)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fi.IsDir() {
|
||||||
|
rel, err := filepath.Rel(manifests, match)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("bad filepath", "path", match, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
n := model.ParseNameFromFilepath(rel)
|
||||||
|
if !n.IsValid() {
|
||||||
|
slog.Warn("bad manifest name", "path", rel, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := ParseNamedManifest(n)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("bad manifest", "name", n, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ms[n] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ms, nil
|
||||||
|
}
|
||||||
|
|
150
server/manifest_test.go
Normal file
150
server/manifest_test.go
Normal file
|
@ -0,0 +1,150 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createManifest(t *testing.T, path, name string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
p := filepath.Join(path, "manifests", name)
|
||||||
|
if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(p)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifests(t *testing.T) {
|
||||||
|
cases := map[string]struct {
|
||||||
|
ps []string
|
||||||
|
wantValidCount int
|
||||||
|
wantInvalidCount int
|
||||||
|
}{
|
||||||
|
"empty": {},
|
||||||
|
"single": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("host", "namespace", "model", "tag"),
|
||||||
|
},
|
||||||
|
wantValidCount: 1,
|
||||||
|
},
|
||||||
|
"multiple": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
|
||||||
|
filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
|
||||||
|
},
|
||||||
|
wantValidCount: 15,
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("host", "namespace", "model", "tag"),
|
||||||
|
filepath.Join("host", "namespace", "model", ".hidden"),
|
||||||
|
},
|
||||||
|
wantValidCount: 1,
|
||||||
|
wantInvalidCount: 1,
|
||||||
|
},
|
||||||
|
"subdir": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("host", "namespace", "model", "tag", "one"),
|
||||||
|
filepath.Join("host", "namespace", "model", "tag", "another", "one"),
|
||||||
|
},
|
||||||
|
wantInvalidCount: 2,
|
||||||
|
},
|
||||||
|
"upper tag": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("host", "namespace", "model", "TAG"),
|
||||||
|
},
|
||||||
|
wantValidCount: 1,
|
||||||
|
},
|
||||||
|
"upper model": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("host", "namespace", "MODEL", "tag"),
|
||||||
|
},
|
||||||
|
wantValidCount: 1,
|
||||||
|
},
|
||||||
|
"upper namespace": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("host", "NAMESPACE", "model", "tag"),
|
||||||
|
},
|
||||||
|
wantValidCount: 1,
|
||||||
|
},
|
||||||
|
"upper host": {
|
||||||
|
ps: []string{
|
||||||
|
filepath.Join("HOST", "namespace", "model", "tag"),
|
||||||
|
},
|
||||||
|
wantValidCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, wants := range cases {
|
||||||
|
t.Run(n, func(t *testing.T) {
|
||||||
|
d := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", d)
|
||||||
|
|
||||||
|
for _, p := range wants.ps {
|
||||||
|
createManifest(t, d, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
ms, err := Manifests()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ns []model.Name
|
||||||
|
for k := range ms {
|
||||||
|
ns = append(ns, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotValidCount, gotInvalidCount int
|
||||||
|
for _, p := range wants.ps {
|
||||||
|
n := model.ParseNameFromFilepath(p)
|
||||||
|
if n.IsValid() {
|
||||||
|
gotValidCount++
|
||||||
|
} else {
|
||||||
|
gotInvalidCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
if !n.IsValid() && slices.Contains(ns, n) {
|
||||||
|
t.Errorf("unexpected invalid name: %s", p)
|
||||||
|
} else if n.IsValid() && !slices.Contains(ns, n) {
|
||||||
|
t.Errorf("missing valid name: %s", p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotValidCount != wants.wantValidCount {
|
||||||
|
t.Errorf("got valid count %d, want %d", gotValidCount, wants.wantValidCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotInvalidCount != wants.wantInvalidCount {
|
||||||
|
t.Errorf("got invalid count %d, want %d", gotInvalidCount, wants.wantInvalidCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -575,48 +575,31 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) DeleteModelHandler(c *gin.Context) {
|
func (s *Server) DeleteModelHandler(c *gin.Context) {
|
||||||
var req api.DeleteRequest
|
var r api.DeleteRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
|
||||||
switch {
|
|
||||||
case errors.Is(err, io.EOF):
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||||
return
|
return
|
||||||
case err != nil:
|
} else if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var model string
|
n := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||||
if req.Model != "" {
|
if !n.IsValid() {
|
||||||
model = req.Model
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||||
} else if req.Name != "" {
|
|
||||||
model = req.Name
|
|
||||||
} else {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DeleteModel(model); err != nil {
|
m, err := ParseNamedManifest(n)
|
||||||
if os.IsNotExist(err) {
|
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
|
|
||||||
} else {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
manifestsPath, err := GetManifestPath()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := PruneDirectory(manifestsPath); err != nil {
|
if err := m.Remove(); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ShowModelHandler(c *gin.Context) {
|
func (s *Server) ShowModelHandler(c *gin.Context) {
|
||||||
|
@ -720,49 +703,25 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListModelsHandler(c *gin.Context) {
|
func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||||
manifests, err := GetManifestPath()
|
ms, err := Manifests()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
models := []api.ModelResponse{}
|
models := []api.ModelResponse{}
|
||||||
if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
|
for n, m := range ms {
|
||||||
if !info.IsDir() {
|
|
||||||
rel, err := filepath.Rel(manifests, path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
|
|
||||||
return err
|
|
||||||
} else if hidden {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
n := model.ParseNameFromFilepath(rel)
|
|
||||||
if !n.IsValid() {
|
|
||||||
slog.Warn("bad manifest filepath", "path", rel)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := ParseNamedManifest(n)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("bad manifest", "name", n, "error", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := m.Config.Open()
|
f, err := m.Config.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("bad manifest config filepath", "name", n, "error", err)
|
slog.Warn("bad manifest filepath", "name", n, "error", err)
|
||||||
return nil
|
continue
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var c ConfigV2
|
var cf ConfigV2
|
||||||
if err := json.NewDecoder(f).Decode(&c); err != nil {
|
if err := json.NewDecoder(f).Decode(&cf); err != nil {
|
||||||
slog.Warn("bad manifest config", "name", n, "error", err)
|
slog.Warn("bad manifest config", "name", n, "error", err)
|
||||||
return nil
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// tag should never be masked
|
// tag should never be masked
|
||||||
|
@ -770,24 +729,18 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||||
Model: n.DisplayShortest(),
|
Model: n.DisplayShortest(),
|
||||||
Name: n.DisplayShortest(),
|
Name: n.DisplayShortest(),
|
||||||
Size: m.Size(),
|
Size: m.Size(),
|
||||||
Digest: m.Digest,
|
Digest: m.digest,
|
||||||
ModifiedAt: info.ModTime(),
|
ModifiedAt: m.fi.ModTime(),
|
||||||
Details: api.ModelDetails{
|
Details: api.ModelDetails{
|
||||||
Format: c.ModelFormat,
|
Format: cf.ModelFormat,
|
||||||
Family: c.ModelFamily,
|
Family: cf.ModelFamily,
|
||||||
Families: c.ModelFamilies,
|
Families: cf.ModelFamilies,
|
||||||
ParameterSize: c.ModelType,
|
ParameterSize: cf.ModelType,
|
||||||
QuantizationLevel: c.FileType,
|
QuantizationLevel: cf.FileType,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
|
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
|
||||||
// most recently modified first
|
// most recently modified first
|
||||||
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
|
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
|
||||||
|
|
160
server/routes_create_test.go
Normal file
160
server/routes_create_test.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stream bool = false
|
||||||
|
|
||||||
|
func createBinFile(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := os.CreateTemp(t.TempDir(), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
http.CloseNotifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecorder() *responseRecorder {
|
||||||
|
return &responseRecorder{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *responseRecorder) CloseNotify() <-chan bool {
|
||||||
|
return make(chan bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
w := NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(body); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request = &http.Request{
|
||||||
|
Body: io.NopCloser(&b),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(c)
|
||||||
|
return w.ResponseRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkFileExists(t *testing.T, p string, expect []string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
actual, err := filepath.Glob(p)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(actual, expect) {
|
||||||
|
t.Fatalf("expected slices to be equal %v", actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFromBin(t *testing.T) {
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
|
||||||
|
var s Server
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||||
|
})
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
|
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateFromModel(t *testing.T) {
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test2",
|
||||||
|
Modelfile: "FROM test",
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||||
|
})
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
|
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
|
||||||
|
})
|
||||||
|
}
|
71
server/routes_delete_test.go
Normal file
71
server/routes_delete_test.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDelete(t *testing.T) {
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test2",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||||
|
})
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
|
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||||
|
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||||
|
})
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
|
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
|
||||||
|
})
|
||||||
|
|
||||||
|
w = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||||
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
|
||||||
|
}
|
61
server/routes_list_test.go
Normal file
61
server/routes_list_test.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestList(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
||||||
|
expectNames := []string{
|
||||||
|
"mistral:7b-instruct-q4_0",
|
||||||
|
"zephyr:7b-beta-q5_K_M",
|
||||||
|
"apple/OpenELM:latest",
|
||||||
|
"boreas:2b-code-v1.5-q6_K",
|
||||||
|
"notus:7b-v1-IQ2_S",
|
||||||
|
// TODO: host:port currently fails on windows (#4107)
|
||||||
|
// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
|
||||||
|
"mynamespace/apeliotes:latest",
|
||||||
|
"myhost/mynamespace/lips:code",
|
||||||
|
}
|
||||||
|
|
||||||
|
var s Server
|
||||||
|
for _, n := range expectNames {
|
||||||
|
createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: n,
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
w := createRequest(t, s.ListModelsHandler, nil)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.ListResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Models) != len(expectNames) {
|
||||||
|
t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
|
||||||
|
}
|
||||||
|
|
||||||
|
actualNames := make([]string, len(resp.Models))
|
||||||
|
for i, m := range resp.Models {
|
||||||
|
actualNames[i] = m.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.Sort(actualNames)
|
||||||
|
slices.Sort(expectNames)
|
||||||
|
|
||||||
|
if !slices.Equal(actualNames, expectNames) {
|
||||||
|
t.Fatalf("expected slices to be equal %v", actualNames)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue