diff --git a/convert/convert.go b/convert/convert.go index 42de080c..f4210e50 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "io" "log/slog" "os" "path/filepath" @@ -47,7 +48,7 @@ type ByteOrder interface { type ModelArch interface { GetTensors() error LoadVocab() error - WriteGGUF() (string, error) + WriteGGUF(io.WriteSeeker) error } type ModelFormat interface { diff --git a/convert/gemma.go b/convert/gemma.go index 648a4ad9..88abe646 100644 --- a/convert/gemma.go +++ b/convert/gemma.go @@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error { return nil } -func (m *GemmaModel) WriteGGUF() (string, error) { +func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error { kv := llm.KV{ "general.architecture": "gemma", "general.name": m.Name, @@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) { "tokenizer.ggml.add_eos_token": false, } - f, err := os.CreateTemp("", "ollama-gguf") - if err != nil { - return "", err - } - defer f.Close() - - mod := llm.NewGGUFV3(m.Params.ByteOrder) - if err := mod.Encode(f, kv, m.Tensors); err != nil { - return "", err - } - - return f.Name(), nil + return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) } diff --git a/convert/llama.go b/convert/llama.go index c7f7b290..fb576e2e 100644 --- a/convert/llama.go +++ b/convert/llama.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "log/slog" - "os" "regexp" "strings" @@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error { return nil } -func (m *LlamaModel) WriteGGUF() (string, error) { +func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error { kv := llm.KV{ "general.architecture": "llama", "general.name": m.Name, @@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) { "tokenizer.ggml.add_eos_token": false, } - f, err := os.CreateTemp("", "ollama-gguf") - if err != nil { - return "", err - } - defer f.Close() - - mod := llm.NewGGUFV3(m.Params.ByteOrder) - if err := mod.Encode(f, 
kv, m.Tensors); err != nil { - return "", err - } - - slog.Debug(fmt.Sprintf("gguf file = %s", f.Name())) - - return f.Name(), nil + return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) } diff --git a/convert/mistral.go b/convert/mistral.go index 70c92edd..f88de12b 100644 --- a/convert/mistral.go +++ b/convert/mistral.go @@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error { return nil } -func (m *MistralModel) WriteGGUF() (string, error) { +func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error { kv := llm.KV{ "general.architecture": "llama", "general.name": m.Name, @@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) { "tokenizer.ggml.unknown_token_id": uint32(0), } - f, err := os.CreateTemp("", "ollama-gguf") - if err != nil { - return "", err - } - defer f.Close() - - mod := llm.NewGGUFV3(m.Params.ByteOrder) - if err := mod.Encode(f, kv, m.Tensors); err != nil { - return "", err - } - - return f.Name(), nil + return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) } diff --git a/convert/mixtral.go b/convert/mixtral.go index e31e84af..940df55d 100644 --- a/convert/mixtral.go +++ b/convert/mixtral.go @@ -1,7 +1,7 @@ package convert import ( - "os" + "io" "regexp" "github.com/ollama/ollama/llm" @@ -47,7 +47,7 @@ func (m *MixtralModel) LoadVocab() error { return nil } -func (m *MixtralModel) WriteGGUF() (string, error) { +func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error { kv := llm.KV{ "general.architecture": "llama", "general.name": m.Name, @@ -81,16 +81,5 @@ func (m *MixtralModel) WriteGGUF() (string, error) { "tokenizer.ggml.add_eos_token": false, } - f, err := os.CreateTemp("", "ollama-gguf") - if err != nil { - return "", err - } - defer f.Close() - - mod := llm.NewGGUFV3(m.Params.ByteOrder) - if err := mod.Encode(f, kv, m.Tensors); err != nil { - return "", err - } - - return f.Name(), nil + return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) } diff --git 
a/integration/utils_test.go b/integration/utils_test.go index 3e91187a..e133e76d 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -107,7 +107,7 @@ func startServer(ctx context.Context, ollamaHost string) error { if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost { slog.Info("setting env", "OLLAMA_HOST", ollamaHost) - os.Setenv("OLLAMA_HOST", ollamaHost) + os.Setenv("OLLAMA_HOST", ollamaHost) } slog.Info("starting server", "url", ollamaHost) diff --git a/llm/filetype.go b/llm/filetype.go new file mode 100644 index 00000000..e5e9410d --- /dev/null +++ b/llm/filetype.go @@ -0,0 +1,140 @@ +package llm + +import "fmt" + +type fileType uint32 + +const ( + fileTypeF32 fileType = iota + fileTypeF16 + fileTypeQ4_0 + fileTypeQ4_1 + fileTypeQ4_1_F16 + fileTypeQ4_2 // unused + fileTypeQ4_3 // unused + fileTypeQ8_0 + fileTypeQ5_0 + fileTypeQ5_1 + fileTypeQ2_K + fileTypeQ3_K_S + fileTypeQ3_K_M + fileTypeQ3_K_L + fileTypeQ4_K_S + fileTypeQ4_K_M + fileTypeQ5_K_S + fileTypeQ5_K_M + fileTypeQ6_K + fileTypeIQ2_XXS + fileTypeIQ2_XS + fileTypeQ2_K_S + fileTypeQ3_K_XS + fileTypeIQ3_XXS + + fileTypeUnknown +) + +func ParseFileType(s string) (fileType, error) { + switch s { + case "F32": + return fileTypeF32, nil + case "F16": + return fileTypeF16, nil + case "Q4_0": + return fileTypeQ4_0, nil + case "Q4_1": + return fileTypeQ4_1, nil + case "Q4_1_F16": + return fileTypeQ4_1_F16, nil + case "Q8_0": + return fileTypeQ8_0, nil + case "Q5_0": + return fileTypeQ5_0, nil + case "Q5_1": + return fileTypeQ5_1, nil + case "Q2_K": + return fileTypeQ2_K, nil + case "Q3_K_S": + return fileTypeQ3_K_S, nil + case "Q3_K_M": + return fileTypeQ3_K_M, nil + case "Q3_K_L": + return fileTypeQ3_K_L, nil + case "Q4_K_S": + return fileTypeQ4_K_S, nil + case "Q4_K_M": + return fileTypeQ4_K_M, nil + case "Q5_K_S": + return fileTypeQ5_K_S, nil + case "Q5_K_M": + return fileTypeQ5_K_M, nil + case "Q6_K": + return fileTypeQ6_K, nil + case "IQ2_XXS": + return fileTypeIQ2_XXS, nil + case 
"IQ2_XS": + return fileTypeIQ2_XS, nil + case "Q2_K_S": + return fileTypeQ2_K_S, nil + case "Q3_K_XS": + return fileTypeQ3_K_XS, nil + case "IQ3_XXS": + return fileTypeIQ3_XXS, nil + default: + return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s) + } +} + +func (t fileType) String() string { + switch t { + case fileTypeF32: + return "F32" + case fileTypeF16: + return "F16" + case fileTypeQ4_0: + return "Q4_0" + case fileTypeQ4_1: + return "Q4_1" + case fileTypeQ4_1_F16: + return "Q4_1_F16" + case fileTypeQ8_0: + return "Q8_0" + case fileTypeQ5_0: + return "Q5_0" + case fileTypeQ5_1: + return "Q5_1" + case fileTypeQ2_K: + return "Q2_K" + case fileTypeQ3_K_S: + return "Q3_K_S" + case fileTypeQ3_K_M: + return "Q3_K_M" + case fileTypeQ3_K_L: + return "Q3_K_L" + case fileTypeQ4_K_S: + return "Q4_K_S" + case fileTypeQ4_K_M: + return "Q4_K_M" + case fileTypeQ5_K_S: + return "Q5_K_S" + case fileTypeQ5_K_M: + return "Q5_K_M" + case fileTypeQ6_K: + return "Q6_K" + case fileTypeIQ2_XXS: + return "IQ2_XXS" + case fileTypeIQ2_XS: + return "IQ2_XS" + case fileTypeQ2_K_S: + return "Q2_K_S" + case fileTypeQ3_K_XS: + return "Q3_K_XS" + case fileTypeIQ3_XXS: + return "IQ3_XXS" + default: + return "unknown" + } +} + +func (t fileType) Value() uint32 { + return uint32(t) +} diff --git a/llm/ggml.go b/llm/ggml.go index 1b094027..a83bba8f 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -13,82 +13,6 @@ type GGML struct { model } -const ( - fileTypeF32 uint32 = iota - fileTypeF16 - fileTypeQ4_0 - fileTypeQ4_1 - fileTypeQ4_1_F16 - fileTypeQ8_0 uint32 = iota + 2 - fileTypeQ5_0 - fileTypeQ5_1 - fileTypeQ2_K - fileTypeQ3_K_S - fileTypeQ3_K_M - fileTypeQ3_K_L - fileTypeQ4_K_S - fileTypeQ4_K_M - fileTypeQ5_K_S - fileTypeQ5_K_M - fileTypeQ6_K - fileTypeIQ2_XXS - fileTypeIQ2_XS - fileTypeQ2_K_S - fileTypeQ3_K_XS - fileTypeIQ3_XXS -) - -func fileType(fileType uint32) string { - switch fileType { - case fileTypeF32: - return "F32" - case fileTypeF16: - return "F16" - case fileTypeQ4_0: - 
return "Q4_0" - case fileTypeQ4_1: - return "Q4_1" - case fileTypeQ4_1_F16: - return "Q4_1_F16" - case fileTypeQ8_0: - return "Q8_0" - case fileTypeQ5_0: - return "Q5_0" - case fileTypeQ5_1: - return "Q5_1" - case fileTypeQ2_K: - return "Q2_K" - case fileTypeQ3_K_S: - return "Q3_K_S" - case fileTypeQ3_K_M: - return "Q3_K_M" - case fileTypeQ3_K_L: - return "Q3_K_L" - case fileTypeQ4_K_S: - return "Q4_K_S" - case fileTypeQ4_K_M: - return "Q4_K_M" - case fileTypeQ5_K_S: - return "Q5_K_S" - case fileTypeQ5_K_M: - return "Q5_K_M" - case fileTypeQ6_K: - return "Q6_K" - case fileTypeIQ2_XXS: - return "IQ2_XXS" - case fileTypeIQ2_XS: - return "IQ2_XS" - case fileTypeQ2_K_S: - return "Q2_K_S" - case fileTypeQ3_K_XS: - return "Q3_K_XS" - case fileTypeIQ3_XXS: - return "IQ3_XXS" - default: - return "unknown" - } -} - type model interface { KV() KV Tensors() Tensors @@ -123,7 +47,7 @@ func (kv KV) ParameterCount() uint64 { func (kv KV) FileType() string { if u64 := kv.u64("general.file_type"); u64 > 0 { - return fileType(uint32(u64)) + return fileType(uint32(u64)).String() } return "unknown" @@ -286,6 +210,23 @@ const ( var ErrUnsupportedFormat = errors.New("unsupported model format") +func DetectGGMLType(b []byte) string { + switch binary.LittleEndian.Uint32(b[:4]) { + case FILE_MAGIC_GGML: + return "ggml" + case FILE_MAGIC_GGMF: + return "ggmf" + case FILE_MAGIC_GGJT: + return "ggjt" + case FILE_MAGIC_GGLA: + return "ggla" + case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE: + return "gguf" + default: + return "" + } +} + func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { var magic uint32 if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { diff --git a/llm/llm.go b/llm/llm.go index c81e2edf..2a0c4b91 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -20,7 +20,7 @@ func SystemInfo() string { return C.GoString(C.llama_print_system_info()) } -func Quantize(infile, outfile, filetype string) error { +func Quantize(infile, outfile string, ftype fileType) error { 
cinfile := C.CString(infile) defer C.free(unsafe.Pointer(cinfile)) @@ -29,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error { params := C.llama_model_quantize_default_params() params.nthread = -1 + params.ftype = ftype.Value() - switch filetype { - case "F32": - params.ftype = fileTypeF32 - case "F16": - params.ftype = fileTypeF16 - case "Q4_0": - params.ftype = fileTypeQ4_0 - case "Q4_1": - params.ftype = fileTypeQ4_1 - case "Q4_1_F16": - params.ftype = fileTypeQ4_1_F16 - case "Q8_0": - params.ftype = fileTypeQ8_0 - case "Q5_0": - params.ftype = fileTypeQ5_0 - case "Q5_1": - params.ftype = fileTypeQ5_1 - case "Q2_K": - params.ftype = fileTypeQ2_K - case "Q3_K_S": - params.ftype = fileTypeQ3_K_S - case "Q3_K_M": - params.ftype = fileTypeQ3_K_M - case "Q3_K_L": - params.ftype = fileTypeQ3_K_L - case "Q4_K_S": - params.ftype = fileTypeQ4_K_S - case "Q4_K_M": - params.ftype = fileTypeQ4_K_M - case "Q5_K_S": - params.ftype = fileTypeQ5_K_S - case "Q5_K_M": - params.ftype = fileTypeQ5_K_M - case "Q6_K": - params.ftype = fileTypeQ6_K - case "IQ2_XXS": - params.ftype = fileTypeIQ2_XXS - case "IQ2_XS": - params.ftype = fileTypeIQ2_XS - case "Q2_K_S": - params.ftype = fileTypeQ2_K_S - case "Q3_K_XS": - params.ftype = fileTypeQ3_K_XS - case "IQ3_XXS": - params.ftype = fileTypeIQ3_XXS - default: - return fmt.Errorf("unknown filetype: %s", filetype) - } - - if retval := C.llama_model_quantize(cinfile, coutfile, ¶ms); retval != 0 { - return fmt.Errorf("llama_model_quantize: %d", retval) + if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 { + return fmt.Errorf("llama_model_quantize: %d", rc) } return nil diff --git a/server/images.go b/server/images.go index 76205392..2817b1d3 100644 --- a/server/images.go +++ b/server/images.go @@ -1,8 +1,8 @@ package server import ( - "archive/zip" "bytes" + "cmp" "context" "crypto/sha256" "encoding/base64" @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "io/fs" "log" "log/slog" "net/http" @@ -26,7 +25,6 @@ 
import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/auth" - "github.com/ollama/ollama/convert" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/server/envconfig" @@ -158,36 +156,6 @@ type ConfigV2 struct { RootFS RootFS `json:"rootfs"` } -func (c *ConfigV2) SetModelFormat(format string) { - if c.ModelFormat == "" { - c.ModelFormat = format - } -} - -func (c *ConfigV2) SetModelFamily(families ...string) { - for _, family := range families { - if c.ModelFamily == "" { - c.ModelFamily = family - } - - if !slices.Contains(c.ModelFamilies, family) { - c.ModelFamilies = append(c.ModelFamilies, family) - } - } -} - -func (c *ConfigV2) SetModelType(modelType string) { - if c.ModelType == "" { - c.ModelType = modelType - } -} - -func (c *ConfigV2) SetFileType(fileType string) { - if c.FileType == "" { - c.FileType = fileType - } -} - type RootFS struct { Type string `json:"type"` DiffIDs []string `json:"diff_ids"` @@ -332,7 +300,7 @@ func GetModel(name string) (*Model, error) { return model, nil } -func realpath(mfDir, from string) string { +func realpath(rel, from string) string { abspath, err := filepath.Abs(from) if err != nil { return from @@ -349,22 +317,15 @@ func realpath(mfDir, from string) string { return filepath.Join(home, from[2:]) } - if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil { + if _, err := os.Stat(filepath.Join(rel, from)); err == nil { // this is a file relative to the Modelfile - return filepath.Join(mfDir, from) + return filepath.Join(rel, from) } return abspath } -func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error { - deleteMap := make(map[string]struct{}) - if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { - for _, layer := range append(manifest.Layers, manifest.Config) { - deleteMap[layer.Digest] = struct{}{} - } - } - +func CreateModel(ctx context.Context, name, 
modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) { config := ConfigV2{ OS: "linux", Architecture: "amd64", @@ -373,250 +334,181 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m }, } - var layers Layers - messages := []string{} - - params := make(map[string][]string) - fromParams := make(map[string]any) + var messages []*api.Message + parameters := make(map[string]any) + var layers []*Layer for _, c := range modelfile.Commands { mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) switch c.Name { - case "model": - if strings.HasPrefix(c.Args, "@") { - blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) + case "model", "adapter": + var baseLayers []*layerWithGGML + if name := model.ParseName(c.Args); name.IsValid() { + baseLayers, err = parseFromModel(ctx, name, fn) + if err != nil { + return err + } + } else if strings.HasPrefix(c.Args, "@") { + blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) if err != nil { return err } - c.Args = blobPath - } - - pathName := realpath(modelFileDir, c.Args) - - ggufName, err := convertModel(name, pathName, fn) - if err != nil { - var pathErr *fs.PathError - switch { - case errors.Is(err, zip.ErrFormat): - // it's not a safetensor archive - case errors.As(err, &pathErr): - // it's not a file on disk, could be a model reference - default: - return err - } - } - - if ggufName != "" { - pathName = ggufName - defer os.RemoveAll(ggufName) - - if quantization != "" { - quantization = strings.ToUpper(quantization) - fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)}) - tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization) - if err != nil { - return err - } - defer os.RemoveAll(tempfile.Name()) - - if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil { - return err - } - - if err := tempfile.Close(); err != nil { - return err - } 
- - pathName = tempfile.Name() - } - } - - bin, err := os.Open(pathName) - if err != nil { - // not a file on disk so must be a model reference - modelpath := ParseModelPath(c.Args) - manifest, _, err := GetManifest(modelpath) - switch { - case errors.Is(err, os.ErrNotExist): - fn(api.ProgressResponse{Status: "pulling model"}) - if err := PullModel(ctx, c.Args, ®istryOptions{}, fn); err != nil { - return err - } - - manifest, _, err = GetManifest(modelpath) - if err != nil { - return err - } - case err != nil: - return err - } - - fn(api.ProgressResponse{Status: "reading model metadata"}) - fromConfigPath, err := GetBlobsPath(manifest.Config.Digest) + blob, err := os.Open(blobpath) if err != nil { return err } + defer blob.Close() - fromConfigFile, err := os.Open(fromConfigPath) + baseLayers, err = parseFromFile(ctx, blob, fn) if err != nil { return err } - defer fromConfigFile.Close() + } else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil { + defer file.Close() - var fromConfig ConfigV2 - if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil { + baseLayers, err = parseFromFile(ctx, file, fn) + if err != nil { return err } + } else { + return fmt.Errorf("invalid model reference: %s", c.Args) + } - // if the model is still not in gguf format, error out - if fromConfig.ModelFormat != "gguf" { - return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args) - } - - config.SetModelFormat(fromConfig.ModelFormat) - config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...) 
- config.SetModelType(fromConfig.ModelType) - config.SetFileType(fromConfig.FileType) - - for _, layer := range manifest.Layers { - deleteMap[layer.Digest] = struct{}{} - if layer.MediaType == "application/vnd.ollama.image.params" { - fromParamsPath, err := GetBlobsPath(layer.Digest) - if err != nil { - return err - } - - fromParamsFile, err := os.Open(fromParamsPath) - if err != nil { - return err - } - defer fromParamsFile.Close() - - if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil { - return err - } - } - - layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + for _, baseLayer := range baseLayers { + if quantization != "" && + baseLayer.MediaType == "application/vnd.ollama.image.model" && + baseLayer.GGML != nil && + baseLayer.GGML.Name() == "gguf" { + ftype, err := llm.ParseFileType(quantization) if err != nil { return err } - layers.Add(layer) + filetype := baseLayer.GGML.KV().FileType() + if !slices.Contains([]string{"F16", "F32"}, filetype) { + return errors.New("quantization is only supported for F16 and F32 models") + } + + fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", filetype, quantization)}) + + blob, err := GetBlobsPath(baseLayer.Digest) + if err != nil { + return err + } + + temp, err := os.CreateTemp(filepath.Dir(blob), quantization) + if err != nil { + return err + } + defer temp.Close() + defer os.Remove(temp.Name()) + + if err := llm.Quantize(blob, temp.Name(), ftype); err != nil { + return err + } + + baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType) + if err != nil { + return err + } } - deleteMap[manifest.Config.Digest] = struct{}{} - continue + if baseLayer.GGML != nil { + config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name()) + config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture()) + config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount())) + 
config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType()) + config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) + } + + layers = append(layers, baseLayer.Layer) } - defer bin.Close() - - var offset int64 - for { - fn(api.ProgressResponse{Status: "creating model layer"}) - if _, err := bin.Seek(offset, io.SeekStart); err != nil { - return err - } - - ggml, size, err := llm.DecodeGGML(bin) - if errors.Is(err, io.EOF) { - break - } else if errors.Is(err, llm.ErrUnsupportedFormat) { - return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err) - } else if err != nil { - return err - } - - config.SetModelFormat(ggml.Name()) - config.SetModelFamily(ggml.KV().Architecture()) - config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount())) - config.SetFileType(ggml.KV().FileType()) - - mediatype := mediatype - if ggml.KV().Architecture() == "clip" { - mediatype = "application/vnd.ollama.image.projector" - } - - sr := io.NewSectionReader(bin, offset, size) - layer, err := NewLayer(sr, mediatype) - if err != nil { - return err - } - - layers.Add(layer) - - offset += size - } - case "adapter": - if strings.HasPrefix(c.Args, "@") { - blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@")) - if err != nil { - return err - } - - c.Args = blobPath - } - - fn(api.ProgressResponse{Status: "creating adapter layer"}) - bin, err := os.Open(realpath(modelFileDir, c.Args)) - if err != nil { - return err - } - defer bin.Close() - - _, size, err := llm.DecodeGGML(bin) + case "license", "template", "system": + blob := strings.NewReader(c.Args) + layer, err := NewLayer(blob, mediatype) if err != nil { return err } - sr := io.NewSectionReader(bin, 0, size) - layer, err := NewLayer(sr, mediatype) - if err != nil { - return err + if c.Name != "license" { + // replace + layers = slices.DeleteFunc(layers, func(layer *Layer) bool { + return layer.MediaType == mediatype + }) } - 
layers.Add(layer) - case "license": - fn(api.ProgressResponse{Status: "creating license layer"}) - - bin := strings.NewReader(c.Args) - layer, err := NewLayer(bin, mediatype) - if err != nil { - return err - } - - layers.Add(layer) - case "template", "system": - fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)}) - - bin := strings.NewReader(c.Args) - layer, err := NewLayer(bin, mediatype) - if err != nil { - return err - } - - layers.Replace(layer) + layers = append(layers, layer) case "message": - messages = append(messages, c.Args) + role, content, ok := strings.Cut(c.Args, ": ") + if !ok { + return fmt.Errorf("invalid message: %s", c.Args) + } + + messages = append(messages, &api.Message{Role: role, Content: content}) default: - params[c.Name] = append(params[c.Name], c.Args) + ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) + if err != nil { + return err + } + + for k, v := range ps { + if ks, ok := parameters[k].([]string); ok { + parameters[k] = append(ks, v.([]string)...) 
+ } else if vs, ok := v.([]string); ok { + parameters[k] = vs + } else { + parameters[k] = v + } + } } } - if len(messages) > 0 { - fn(api.ProgressResponse{Status: "creating parameters layer"}) + var err2 error + layers = slices.DeleteFunc(layers, func(layer *Layer) bool { + switch layer.MediaType { + case "application/vnd.ollama.image.message": + // if there are new messages, remove the inherited ones + if len(messages) > 0 { + return true + } - msgs := make([]api.Message, 0) + return false + case "application/vnd.ollama.image.params": + // merge inherited parameters with new ones + r, err := layer.Open() + if err != nil { + err2 = err + return false + } + defer r.Close() - for _, m := range messages { - // todo: handle images - msg := strings.SplitN(m, ": ", 2) - msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]}) + var ps map[string]any + if err := json.NewDecoder(r).Decode(&ps); err != nil { + err2 = err + return false + } + + for k, v := range ps { + if _, ok := parameters[k]; !ok { + parameters[k] = v + } + } + + return true + default: + return false } + }) + if err2 != nil { + return err2 + } + + if len(messages) > 0 { var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(msgs); err != nil { + if err := json.NewEncoder(&b).Encode(messages); err != nil { return err } @@ -625,39 +517,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return err } - layers.Replace(layer) + layers = append(layers, layer) } - if len(params) > 0 { - fn(api.ProgressResponse{Status: "creating parameters layer"}) - - formattedParams, err := api.FormatParams(params) - if err != nil { - return err - } - - for k, v := range fromParams { - if _, ok := formattedParams[k]; !ok { - formattedParams[k] = v - } - } - + if len(parameters) > 0 { var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(formattedParams); err != nil { + if err := json.NewEncoder(&b).Encode(parameters); err != nil { return err } - fn(api.ProgressResponse{Status: 
"creating config layer"}) layer, err := NewLayer(&b, "application/vnd.ollama.image.params") if err != nil { return err } - layers.Replace(layer) + layers = append(layers, layer) } - digests := make([]string, len(layers.items)) - for i, layer := range layers.items { + digests := make([]string, len(layers)) + for i, layer := range layers { digests[i] = layer.Digest } @@ -668,36 +546,37 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return err } - configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json") + layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json") if err != nil { return err } - delete(deleteMap, configLayer.Digest) + for _, layer := range append(layers, layer) { + if layer.status != "" { + fn(api.ProgressResponse{Status: layer.status}) + } + } - for _, layer := range append(layers.items, configLayer) { - committed, err := layer.Commit() - if err != nil { - return err + unref := make(map[string]struct{}) + if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil { + for _, layer := range manifest.Layers { + if !slices.Contains(digests, layer.Digest) { + unref[layer.Digest] = struct{}{} + } } - status := "writing layer" - if !committed { - status = "using already created layer" + if manifest.Config.Digest != layer.Digest { + unref[manifest.Config.Digest] = struct{}{} } - - fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)}) - - delete(deleteMap, layer.Digest) } fn(api.ProgressResponse{Status: "writing manifest"}) - if err := WriteManifest(name, configLayer, layers.items); err != nil { + if err := WriteManifest(name, layer, layers); err != nil { return err } if !envconfig.NoPrune { - if err := deleteUnusedLayers(nil, deleteMap, false); err != nil { + if err := deleteUnusedLayers(nil, unref, false); err != nil { return err } } @@ -706,74 +585,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m return nil } 
-func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) { - r, err := zip.OpenReader(path) - if err != nil { - return "", err - } - defer r.Close() - - tempDir, err := os.MkdirTemp("", "ollama-convert") - if err != nil { - return "", err - } - defer os.RemoveAll(tempDir) - - fn(api.ProgressResponse{Status: "unpacking model metadata"}) - for _, f := range r.File { - fpath := filepath.Join(tempDir, f.Name) - outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) - if err != nil { - return "", err - } - - rc, err := f.Open() - if err != nil { - return "", err - } - - _, err = io.Copy(outFile, rc) - if err != nil { - return "", err - } - - outFile.Close() - rc.Close() - } - - mf, err := convert.GetModelFormat(tempDir) - if err != nil { - return "", err - } - - params, err := mf.GetParams(tempDir) - if err != nil { - return "", err - } - - mArch, err := mf.GetModelArch(name, tempDir, params) - if err != nil { - return "", err - } - - fn(api.ProgressResponse{Status: "processing tensors"}) - if err := mArch.GetTensors(); err != nil { - return "", err - } - - if err := mArch.LoadVocab(); err != nil { - return "", err - } - - fn(api.ProgressResponse{Status: "converting model"}) - path, err = mArch.WriteGGUF() - if err != nil { - return "", err - } - - return path, nil -} - func CopyModel(src, dst model.Name) error { if !dst.IsFullyQualified() { return model.Unqualified(dst) diff --git a/server/layers.go b/server/layer.go similarity index 53% rename from server/layers.go rename to server/layer.go index 07787406..dcca3854 100644 --- a/server/layers.go +++ b/server/layer.go @@ -5,39 +5,14 @@ import ( "fmt" "io" "os" - "strings" - - "golang.org/x/exp/slices" ) -type Layers struct { - items []*Layer -} - -func (ls *Layers) Add(layer *Layer) { - if layer.Size > 0 { - ls.items = append(ls.items, layer) - } -} - -func (ls *Layers) Replace(layer *Layer) { - if layer.Size > 0 { - mediatype := layer.MediaType - layers := 
slices.DeleteFunc(ls.items, func(l *Layer) bool { - return l.MediaType == mediatype - }) - - ls.items = append(layers, layer) - } -} - type Layer struct { MediaType string `json:"mediaType"` Digest string `json:"digest"` Size int64 `json:"size"` From string `json:"from,omitempty"` - - tempFileName string + status string } func NewLayer(r io.Reader, mediatype string) (*Layer, error) { @@ -46,14 +21,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) { return nil, err } - const delimiter = "-" - - pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter) - temp, err := os.CreateTemp(blobs, pattern) + temp, err := os.CreateTemp(blobs, "sha256-") if err != nil { return nil, err } defer temp.Close() + defer os.Remove(temp.Name()) sha256sum := sha256.New() n, err := io.Copy(io.MultiWriter(temp, sha256sum), r) @@ -61,11 +34,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) { return nil, err } + if err := temp.Close(); err != nil { + return nil, err + } + + digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)) + blob, err := GetBlobsPath(digest) + if err != nil { + return nil, err + } + + status := "using existing layer" + if _, err := os.Stat(blob); err != nil { + status = "creating new layer" + if err := os.Rename(temp.Name(), blob); err != nil { + return nil, err + } + } + return &Layer{ - MediaType: mediatype, - Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)), - Size: n, - tempFileName: temp.Name(), + MediaType: mediatype, + Digest: digest, + Size: n, + status: fmt.Sprintf("%s %s", status, digest), }, nil } @@ -85,21 +76,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) { Digest: digest, Size: fi.Size(), From: from, + status: fmt.Sprintf("using existing layer %s", digest), }, nil } -func (l *Layer) Commit() (bool, error) { - // always remove temp - defer os.Remove(l.tempFileName) - +func (l *Layer) Open() (io.ReadCloser, error) { blob, err := GetBlobsPath(l.Digest) if err != nil { - return 
false, err + return nil, err } - if _, err := os.Stat(blob); err != nil { - return true, os.Rename(l.tempFileName, blob) - } - - return false, nil + return os.Open(blob) } diff --git a/server/model.go b/server/model.go new file mode 100644 index 00000000..eea5d13a --- /dev/null +++ b/server/model.go @@ -0,0 +1,261 @@ +package server + +import ( + "archive/zip" + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/convert" + "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/types/model" +) + +type layerWithGGML struct { + *Layer + *llm.GGML +} + +func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { + modelpath := ParseModelPath(name.String()) + manifest, _, err := GetManifest(modelpath) + switch { + case errors.Is(err, os.ErrNotExist): + if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { + return nil, err + } + + modelpath = ParseModelPath(name.String()) + manifest, _, err = GetManifest(modelpath) + if err != nil { + return nil, err + } + case err != nil: + return nil, err + } + + for _, layer := range manifest.Layers { + layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + if err != nil { + return nil, err + } + + switch layer.MediaType { + case "application/vnd.ollama.image.model", + "application/vnd.ollama.image.projector", + "application/vnd.ollama.image.adapter": + blobpath, err := GetBlobsPath(layer.Digest) + if err != nil { + return nil, err + } + + blob, err := os.Open(blobpath) + if err != nil { + return nil, err + } + defer blob.Close() + + ggml, _, err := llm.DecodeGGML(blob) + if err != nil { + return nil, err + } + + layers = append(layers, &layerWithGGML{layer, ggml}) + default: + layers = append(layers, &layerWithGGML{layer, nil}) + } + + } + + return layers, nil +} + +func parseFromZipFile(_ context.Context, 
file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { + stat, err := file.Stat() + if err != nil { + return nil, err + } + + r, err := zip.NewReader(file, stat.Size()) + if err != nil { + return nil, err + } + + tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "") + if err != nil { + return nil, err + } + defer os.RemoveAll(tempdir) + + fn(api.ProgressResponse{Status: "unpacking model metadata"}) + for _, f := range r.File { + // TODO(mxyng): this should not write out all files to disk + outfile, err := os.Create(filepath.Join(tempdir, f.Name)) + if err != nil { + return nil, err + } + defer outfile.Close() + + infile, err := f.Open() + if err != nil { + return nil, err + } + defer infile.Close() + + if _, err = io.Copy(outfile, infile); err != nil { + return nil, err + } + + if err := outfile.Close(); err != nil { + return nil, err + } + + if err := infile.Close(); err != nil { + return nil, err + } + } + + mf, err := convert.GetModelFormat(tempdir) + if err != nil { + return nil, err + } + + params, err := mf.GetParams(tempdir) + if err != nil { + return nil, err + } + + mArch, err := mf.GetModelArch("", tempdir, params) + if err != nil { + return nil, err + } + + fn(api.ProgressResponse{Status: "processing tensors"}) + if err := mArch.GetTensors(); err != nil { + return nil, err + } + + if err := mArch.LoadVocab(); err != nil { + return nil, err + } + + fn(api.ProgressResponse{Status: "converting model"}) + + // TODO(mxyng): this should write directly into a layer + // e.g. 
NewLayer(arch.Reader(), "application/vnd.ollama.image.model") + temp, err := os.CreateTemp(tempdir, "fp16") + if err != nil { + return nil, err + } + defer temp.Close() + defer os.Remove(temp.Name()) + + if err = mArch.WriteGGUF(temp); err != nil { + return nil, err + } + + if _, err := temp.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + layer, err := NewLayer(temp, "application/vnd.ollama.image.model") + if err != nil { + return nil, fmt.Errorf("creating layer from converted model: %w", err) + } + + blobpath, err := GetBlobsPath(layer.Digest) + if err != nil { + return nil, err + } + + bin, err := os.Open(blobpath) + if err != nil { + return nil, err + } + defer bin.Close() + + ggml, _, err := llm.DecodeGGML(bin) + if err != nil { + return nil, err + } + + layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "") + if err != nil { + return nil, err + } + + layers = append(layers, &layerWithGGML{layer, ggml}) + return layers, nil +} + +func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { + sr := io.NewSectionReader(file, 0, 512) + contentType, err := detectContentType(sr) + if err != nil { + return nil, err + } + + switch contentType { + case "gguf", "ggla": + // noop + case "application/zip": + return parseFromZipFile(ctx, file, fn) + default: + return nil, fmt.Errorf("unsupported content type: %s", contentType) + } + + stat, err := file.Stat() + if err != nil { + return nil, err + } + + var offset int64 + for offset < stat.Size() { + ggml, n, err := llm.DecodeGGML(file) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, err + } + + mediatype := "application/vnd.ollama.image.model" + if ggml.Name() == "ggla" { + mediatype = "application/vnd.ollama.image.adapter" + } else if ggml.KV().Architecture() == "clip" { + mediatype = "application/vnd.ollama.image.projector" + } + + layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype) + if err != nil { + return nil,
err + } + + layers = append(layers, &layerWithGGML{layer, ggml}) + offset = n + } + + return layers, nil +} + +func detectContentType(r io.Reader) (string, error) { + var b bytes.Buffer + if _, err := io.Copy(&b, r); err != nil { + return "", err + } + + if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" { + return contentType, nil + } + + if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" { + return contentType, nil + } + + return "unknown", nil +} diff --git a/server/routes.go b/server/routes.go index da51fbbe..e9b7d3b0 100644 --- a/server/routes.go +++ b/server/routes.go @@ -560,7 +560,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil { + if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -852,11 +852,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) { return } - if _, err := layer.Commit(); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.Status(http.StatusCreated) } diff --git a/server/routes_test.go b/server/routes_test.go index 27e53cbd..896dc27b 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) { Method: http.MethodPost, Path: "/api/create", Setup: func(t *testing.T, req *http.Request) { - f, err := os.CreateTemp(t.TempDir(), "ollama-model") - assert.Nil(t, err) - defer f.Close() + fname := createTestFile(t, "ollama-model") stream := false createReq := api.CreateRequest{ Name: "t-bone", - Modelfile: fmt.Sprintf("FROM %s", f.Name()), + Modelfile: fmt.Sprintf("FROM %s", fname), Stream: &stream, } jsonData, err := json.Marshal(createReq) @@ -216,27 +214,25 
@@ func Test_Routes(t *testing.T) { httpSrv := httptest.NewServer(router) t.Cleanup(httpSrv.Close) - workDir, err := os.MkdirTemp("", "ollama-test") - assert.Nil(t, err) - defer os.RemoveAll(workDir) - os.Setenv("OLLAMA_MODELS", workDir) + t.Setenv("OLLAMA_MODELS", t.TempDir()) for _, tc := range testCases { - t.Logf("Running Test: [%s]", tc.Name) - u := httpSrv.URL + tc.Path - req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) - assert.Nil(t, err) + t.Run(tc.Name, func(t *testing.T) { + u := httpSrv.URL + tc.Path + req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) + assert.Nil(t, err) - if tc.Setup != nil { - tc.Setup(t, req) - } + if tc.Setup != nil { + tc.Setup(t, req) + } - resp, err := httpSrv.Client().Do(req) - assert.Nil(t, err) - defer resp.Body.Close() + resp, err := httpSrv.Client().Do(req) + assert.Nil(t, err) + defer resp.Body.Close() - if tc.Expected != nil { - tc.Expected(t, resp) - } + if tc.Expected != nil { + tc.Expected(t, resp) + } + }) } }