diff --git a/server/images.go b/server/images.go index 4d4b47c4..5da47b79 100644 --- a/server/images.go +++ b/server/images.go @@ -30,7 +30,6 @@ import ( "github.com/ollama/ollama/server/envconfig" "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" - "github.com/ollama/ollama/types/ordered" "github.com/ollama/ollama/version" ) @@ -344,7 +343,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m switch c.Name { case "model", "adapter": - var baseLayers *ordered.Map[*Layer, *llm.GGML] + var baseLayers []*layerWithGGML if name := model.ParseName(c.Args); name.IsValid() { baseLayers, err = parseFromModel(ctx, name, fn) if err != nil { @@ -377,70 +376,51 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return fmt.Errorf("invalid model reference: %s", c.Args) } - var err2 error - var tempfiles []*os.File - - // TODO(mxyng): replace with rangefunc - baseLayers.Items()(func(layer *Layer, ggml *llm.GGML) bool { - if quantization != "" && ggml != nil && ggml.Name() == "gguf" { + for _, baseLayer := range baseLayers { + if quantization != "" && baseLayer.GGML != nil && baseLayer.GGML.Name() == "gguf" { ftype, err := llm.ParseFileType(quantization) if err != nil { - err2 = err - return false + return err } - filetype := ggml.KV().FileType() + filetype := baseLayer.GGML.KV().FileType() if !slices.Contains([]string{"F16", "F32"}, filetype) { - err2 = errors.New("quantization is only supported for F16 and F32 models") - return false + return errors.New("quantization is only supported for F16 and F32 models") } fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", filetype, quantization)}) - blob, err := GetBlobsPath(layer.Digest) + blob, err := GetBlobsPath(baseLayer.Digest) if err != nil { - err2 = err - return false + return err } temp, err := os.CreateTemp(filepath.Dir(blob), quantization) if err != nil { - err2 = err - return false + return err } - tempfiles = append(tempfiles, temp) + defer temp.Close() + defer os.Remove(temp.Name()) if err := llm.Quantize(blob, temp.Name(), ftype); err != nil { - err2 = err - return false + return err } - layer, err = NewLayer(temp, layer.MediaType) + baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType) if err != nil { - err2 = err - return false + return err } } - if ggml != nil { - config.ModelFormat = cmp.Or(config.ModelFormat, ggml.Name()) - config.ModelFamily = cmp.Or(config.ModelFamily, ggml.KV().Architecture()) - config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(ggml.KV().ParameterCount())) - config.FileType = cmp.Or(config.FileType, ggml.KV().FileType()) - config.ModelFamilies = append(config.ModelFamilies, ggml.KV().Architecture()) + if baseLayer.GGML != nil { + config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name()) + config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture()) + config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount())) + config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType()) + config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) } - layers = append(layers, layer) - return true - }) - - for _, tempfile := range tempfiles { - defer tempfile.Close() - defer os.Remove(tempfile.Name()) - } - - if err2 != nil { - return err2 + layers = append(layers, baseLayer.Layer) } case "license", "template", "system": blob := strings.NewReader(c.Args) diff --git a/server/model.go b/server/model.go index cf036052..b27c7083 100644 --- a/server/model.go +++ b/server/model.go @@ -15,10 +15,14 @@ import ( "github.com/ollama/ollama/convert" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/types/model" - "github.com/ollama/ollama/types/ordered" ) -func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) { +type layerWithGGML struct { + *Layer + *llm.GGML +} + +func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { modelpath := ParseModelPath(name.DisplayLongest()) manifest, _, err := GetManifest(modelpath) switch { @@ -36,7 +40,6 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return nil, err } - layers := ordered.NewMap[*Layer, *llm.GGML]() for _, layer := range manifest.Layers { layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) if err != nil { @@ -62,9 +65,10 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe if err != nil { return nil, err } - layers.Add(layer, ggml) + + layers = append(layers, &layerWithGGML{layer, ggml}) default: - layers.Add(layer, nil) + layers = append(layers, &layerWithGGML{layer, nil}) } } @@ -72,7 +76,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe return layers, nil } -func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) { +func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { stat, err := file.Stat() if err != nil { return nil, err @@ -184,12 +188,11 @@ func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResp return nil, err } - layers := ordered.NewMap[*Layer, *llm.GGML]() - layers.Add(layer, ggml) + layers = append(layers, &layerWithGGML{layer, ggml}) return layers, nil } -func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (*ordered.Map[*Layer, *llm.GGML], error) { +func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { sr := io.NewSectionReader(file, 0, 512) contentType, err := detectContentType(sr) if err != nil { @@ -205,8 +208,6 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo return nil, fmt.Errorf("unsupported content type: %s", contentType) } - layers := ordered.NewMap[*Layer, *llm.GGML]() - stat, err := file.Stat() if err != nil { return nil, err @@ -233,7 +234,7 @@ func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressRespo return nil, err } - layers.Add(layer, ggml) + layers = append(layers, &layerWithGGML{layer, ggml}) offset = n } diff --git a/types/ordered/map.go b/types/ordered/map.go deleted file mode 100644 index 076d657d..00000000 --- a/types/ordered/map.go +++ /dev/null @@ -1,32 +0,0 @@ -package ordered - -type Map[K comparable, V any] struct { - s []K - m map[K]V -} - -func NewMap[K comparable, V any]() *Map[K, V] { - return &Map[K, V]{ - s: make([]K, 0), - m: make(map[K]V), - } -} - -type iter_Seq2[K, V any] func(func(K, V) bool) - -func (m *Map[K, V]) Items() iter_Seq2[K, V] { - return func(yield func(K, V) bool) { - for _, k := range m.s { - if !yield(k, m.m[k]) { - return - } - } - } -} - -func (m *Map[K, V]) Add(k K, v V) { - if _, ok := m.m[k]; !ok { - m.s = append(m.s, k) - m.m[k] = v - } -}