Merge pull request #3682 from ollama/mxyng/quantize-all-the-things

quantize any fp16/fp32 model
Michael Yang 2024-05-07 15:20:49 -07:00 committed by GitHub
commit 1e0a669f75
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 641 additions and 606 deletions

@@ -5,6 +5,7 @@ import (
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"io"
 	"log/slog"
 	"os"
 	"path/filepath"
@@ -47,7 +48,7 @@ type ByteOrder interface {
 type ModelArch interface {
 	GetTensors() error
 	LoadVocab() error
-	WriteGGUF() (string, error)
+	WriteGGUF(io.WriteSeeker) error
 }

 type ModelFormat interface {

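Note: the interface change inverts ownership of the output. Converters used to create their own temp file and return its path; now the caller supplies any io.WriteSeeker. A minimal caller-side sketch (names and error handling illustrative, not the exact server code):

	temp, err := os.CreateTemp("", "fp16")
	if err != nil {
		return err
	}
	defer temp.Close()
	defer os.Remove(temp.Name())

	// arch is any convert.ModelArch; it streams GGUF bytes into the
	// writer the caller owns instead of naming a file itself.
	if err := arch.WriteGGUF(temp); err != nil {
		return err
	}

	// Rewind so the same handle can be read back, e.g. into a blob layer.
	if _, err := temp.Seek(0, io.SeekStart); err != nil {
		return err
	}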
@@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
 	return nil
 }

-func (m *GemmaModel) WriteGGUF() (string, error) {
+func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "gemma",
 		"general.name":         m.Name,
@@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token": false,
 	}

-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
-	"os"
 	"regexp"
 	"strings"
@@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
 	return nil
 }

-func (m *LlamaModel) WriteGGUF() (string, error) {
+func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "llama",
 		"general.name":         m.Name,
@@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token": false,
 	}

-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
 	return nil
 }

-func (m *MistralModel) WriteGGUF() (string, error) {
+func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "llama",
 		"general.name":         m.Name,
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.unknown_token_id": uint32(0),
 	}

-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

@@ -1,7 +1,7 @@
 package convert

 import (
-	"os"
+	"io"
 	"regexp"

 	"github.com/ollama/ollama/llm"
@@ -47,7 +47,7 @@ func (m *MixtralModel) LoadVocab() error {
 	return nil
 }

-func (m *MixtralModel) WriteGGUF() (string, error) {
+func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
 	kv := llm.KV{
 		"general.architecture": "llama",
 		"general.name":         m.Name,
@@ -81,16 +81,5 @@ func (m *MixtralModel) WriteGGUF() (string, error) {
 		"tokenizer.ggml.add_eos_token": false,
 	}

-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	return f.Name(), nil
+	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }

@@ -107,7 +107,7 @@ func startServer(ctx context.Context, ollamaHost string) error {
 	if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
 		slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
-		os.Setenv("OLLAMA_HOST", ollamaHost)
+		t.Setenv("OLLAMA_HOST", ollamaHost)
 	}

 	slog.Info("starting server", "url", ollamaHost)

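Note: unlike os.Setenv, t.Setenv scopes the variable to the test: the framework restores the previous value on cleanup, and it panics if the test also calls t.Parallel(). A tiny illustration (test name hypothetical):

	func TestServerHost(t *testing.T) {
		// set for this test only; restored automatically afterwards
		t.Setenv("OLLAMA_HOST", "127.0.0.1:11434")
		// ... exercise code that reads OLLAMA_HOST ...
	}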
llm/filetype.go (new file, +140)

@@ -0,0 +1,140 @@
package llm

import "fmt"

type fileType uint32

const (
fileTypeF32 fileType = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ4_2 // unused
fileTypeQ4_3 // unused
fileTypeQ8_0
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
fileTypeIQ2_XXS
fileTypeIQ2_XS
fileTypeQ2_K_S
fileTypeQ3_K_XS
fileTypeIQ3_XXS
fileTypeUnknown
)

func ParseFileType(s string) (fileType, error) {
switch s {
case "F32":
return fileTypeF32, nil
case "F16":
return fileTypeF16, nil
case "Q4_0":
return fileTypeQ4_0, nil
case "Q4_1":
return fileTypeQ4_1, nil
case "Q4_1_F16":
return fileTypeQ4_1_F16, nil
case "Q8_0":
return fileTypeQ8_0, nil
case "Q5_0":
return fileTypeQ5_0, nil
case "Q5_1":
return fileTypeQ5_1, nil
case "Q2_K":
return fileTypeQ2_K, nil
case "Q3_K_S":
return fileTypeQ3_K_S, nil
case "Q3_K_M":
return fileTypeQ3_K_M, nil
case "Q3_K_L":
return fileTypeQ3_K_L, nil
case "Q4_K_S":
return fileTypeQ4_K_S, nil
case "Q4_K_M":
return fileTypeQ4_K_M, nil
case "Q5_K_S":
return fileTypeQ5_K_S, nil
case "Q5_K_M":
return fileTypeQ5_K_M, nil
case "Q6_K":
return fileTypeQ6_K, nil
case "IQ2_XXS":
return fileTypeIQ2_XXS, nil
case "IQ2_XS":
return fileTypeIQ2_XS, nil
case "Q2_K_S":
return fileTypeQ2_K_S, nil
case "Q3_K_XS":
return fileTypeQ3_K_XS, nil
case "IQ3_XXS":
return fileTypeIQ3_XXS, nil
default:
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
}
}

func (t fileType) String() string {
switch t {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case fileTypeQ3_K_XS:
return "Q3_K_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
default:
return "unknown"
}
}

func (t fileType) Value() uint32 {
return uint32(t)
}

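Note: parsing and formatting now round-trip through one table. A quick sketch of the intended use (values follow the constants above; Q4_K_M sits at index 15 because Q4_2/Q4_3 keep their unused slots):

	ftype, err := llm.ParseFileType("Q4_K_M")
	if err != nil {
		return err // unknown strings come back as fileTypeUnknown plus an error
	}

	fmt.Println(ftype.String()) // "Q4_K_M"
	fmt.Println(ftype.Value())  // 15, the raw uint32 handed to llama.cpp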
@@ -13,82 +13,6 @@ type GGML struct {
 	model
 }

-const (
-	fileTypeF32 uint32 = iota
-	fileTypeF16
-	fileTypeQ4_0
-	fileTypeQ4_1
-	fileTypeQ4_1_F16
-	fileTypeQ8_0 uint32 = iota + 2
-	fileTypeQ5_0
-	fileTypeQ5_1
-	fileTypeQ2_K
-	fileTypeQ3_K_S
-	fileTypeQ3_K_M
-	fileTypeQ3_K_L
-	fileTypeQ4_K_S
-	fileTypeQ4_K_M
-	fileTypeQ5_K_S
-	fileTypeQ5_K_M
-	fileTypeQ6_K
-	fileTypeIQ2_XXS
-	fileTypeIQ2_XS
-	fileTypeQ2_K_S
-	fileTypeQ3_K_XS
-	fileTypeIQ3_XXS
-)
-
-func fileType(fileType uint32) string {
-	switch fileType {
-	case fileTypeF32:
-		return "F32"
-	case fileTypeF16:
-		return "F16"
-	case fileTypeQ4_0:
-		return "Q4_0"
-	case fileTypeQ4_1:
-		return "Q4_1"
-	case fileTypeQ4_1_F16:
-		return "Q4_1_F16"
-	case fileTypeQ8_0:
-		return "Q8_0"
-	case fileTypeQ5_0:
-		return "Q5_0"
-	case fileTypeQ5_1:
-		return "Q5_1"
-	case fileTypeQ2_K:
-		return "Q2_K"
-	case fileTypeQ3_K_S:
-		return "Q3_K_S"
-	case fileTypeQ3_K_M:
-		return "Q3_K_M"
-	case fileTypeQ3_K_L:
-		return "Q3_K_L"
-	case fileTypeQ4_K_S:
-		return "Q4_K_S"
-	case fileTypeQ4_K_M:
-		return "Q4_K_M"
-	case fileTypeQ5_K_S:
-		return "Q5_K_S"
-	case fileTypeQ5_K_M:
-		return "Q5_K_M"
-	case fileTypeQ6_K:
-		return "Q6_K"
-	case fileTypeIQ2_XXS:
-		return "IQ2_XXS"
-	case fileTypeIQ2_XS:
-		return "IQ2_XS"
-	case fileTypeQ2_K_S:
-		return "Q2_K_S"
-	case fileTypeQ3_K_XS:
-		return "Q3_K_XS"
-	case fileTypeIQ3_XXS:
-		return "IQ3_XXS"
-	default:
-		return "unknown"
-	}
-}
-
 type model interface {
 	KV() KV
 	Tensors() Tensors
@@ -123,7 +47,7 @@ func (kv KV) ParameterCount() uint64 {
 func (kv KV) FileType() string {
 	if u64 := kv.u64("general.file_type"); u64 > 0 {
-		return fileType(uint32(u64))
+		return fileType(uint32(u64)).String()
 	}

 	return "unknown"
@@ -286,6 +210,23 @@ const (
 var ErrUnsupportedFormat = errors.New("unsupported model format")

+func DetectGGMLType(b []byte) string {
+	switch binary.LittleEndian.Uint32(b[:4]) {
+	case FILE_MAGIC_GGML:
+		return "ggml"
+	case FILE_MAGIC_GGMF:
+		return "ggmf"
+	case FILE_MAGIC_GGJT:
+		return "ggjt"
+	case FILE_MAGIC_GGLA:
+		return "ggla"
+	case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
+		return "gguf"
+	default:
+		return ""
+	}
+}
+
 func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {

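Note: DetectGGMLType dispatches on the first four bytes only. A GGUF file begins with the ASCII bytes "GGUF", which decode as the little-endian uint32 FILE_MAGIC_GGUF_LE, so for example:

	header := []byte("GGUF\x03\x00\x00\x00") // magic plus version 3, little endian
	fmt.Println(llm.DetectGGMLType(header))  // "gguf"
	fmt.Println(llm.DetectGGMLType([]byte{0, 0, 0, 0})) // "" for anything unrecognized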
@@ -20,7 +20,7 @@ func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }

-func Quantize(infile, outfile, filetype string) error {
+func Quantize(infile, outfile string, ftype fileType) error {
 	cinfile := C.CString(infile)
 	defer C.free(unsafe.Pointer(cinfile))
@@ -29,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
 	params := C.llama_model_quantize_default_params()
 	params.nthread = -1
+	params.ftype = ftype.Value()

-	switch filetype {
-	case "F32":
-		params.ftype = fileTypeF32
-	case "F16":
-		params.ftype = fileTypeF16
-	case "Q4_0":
-		params.ftype = fileTypeQ4_0
-	case "Q4_1":
-		params.ftype = fileTypeQ4_1
-	case "Q4_1_F16":
-		params.ftype = fileTypeQ4_1_F16
-	case "Q8_0":
-		params.ftype = fileTypeQ8_0
-	case "Q5_0":
-		params.ftype = fileTypeQ5_0
-	case "Q5_1":
-		params.ftype = fileTypeQ5_1
-	case "Q2_K":
-		params.ftype = fileTypeQ2_K
-	case "Q3_K_S":
-		params.ftype = fileTypeQ3_K_S
-	case "Q3_K_M":
-		params.ftype = fileTypeQ3_K_M
-	case "Q3_K_L":
-		params.ftype = fileTypeQ3_K_L
-	case "Q4_K_S":
-		params.ftype = fileTypeQ4_K_S
-	case "Q4_K_M":
-		params.ftype = fileTypeQ4_K_M
-	case "Q5_K_S":
-		params.ftype = fileTypeQ5_K_S
-	case "Q5_K_M":
-		params.ftype = fileTypeQ5_K_M
-	case "Q6_K":
-		params.ftype = fileTypeQ6_K
-	case "IQ2_XXS":
-		params.ftype = fileTypeIQ2_XXS
-	case "IQ2_XS":
-		params.ftype = fileTypeIQ2_XS
-	case "Q2_K_S":
-		params.ftype = fileTypeQ2_K_S
-	case "Q3_K_XS":
-		params.ftype = fileTypeQ3_K_XS
-	case "IQ3_XXS":
-		params.ftype = fileTypeIQ3_XXS
-	default:
-		return fmt.Errorf("unknown filetype: %s", filetype)
-	}
-
-	if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", retval)
+	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
+		return fmt.Errorf("llama_model_quantize: %d", rc)
 	}

 	return nil

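Note: filetype validation now happens once, up front, instead of inside the cgo wrapper. A caller-side sketch (paths hypothetical):

	ftype, err := llm.ParseFileType("Q4_K_M")
	if err != nil {
		return err // bad names are rejected before touching llama.cpp
	}

	// infile must be an F16/F32 GGUF; outfile receives the quantized model.
	if err := llm.Quantize("model-f16.gguf", "model-q4_k_m.gguf", ftype); err != nil {
		return err
	}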
@@ -1,8 +1,8 @@
 package server

 import (
-	"archive/zip"
 	"bytes"
+	"cmp"
 	"context"
 	"crypto/sha256"
 	"encoding/base64"
@@ -11,7 +11,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/fs"
 	"log"
 	"log/slog"
 	"net/http"
@@ -26,7 +25,6 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/auth"
-	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/server/envconfig"
@@ -158,36 +156,6 @@ type ConfigV2 struct {
 	RootFS RootFS `json:"rootfs"`
 }

-func (c *ConfigV2) SetModelFormat(format string) {
-	if c.ModelFormat == "" {
-		c.ModelFormat = format
-	}
-}
-
-func (c *ConfigV2) SetModelFamily(families ...string) {
-	for _, family := range families {
-		if c.ModelFamily == "" {
-			c.ModelFamily = family
-		}
-
-		if !slices.Contains(c.ModelFamilies, family) {
-			c.ModelFamilies = append(c.ModelFamilies, family)
-		}
-	}
-}
-
-func (c *ConfigV2) SetModelType(modelType string) {
-	if c.ModelType == "" {
-		c.ModelType = modelType
-	}
-}
-
-func (c *ConfigV2) SetFileType(fileType string) {
-	if c.FileType == "" {
-		c.FileType = fileType
-	}
-}
-
 type RootFS struct {
 	Type    string   `json:"type"`
 	DiffIDs []string `json:"diff_ids"`
@@ -332,7 +300,7 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }

-func realpath(mfDir, from string) string {
+func realpath(rel, from string) string {
 	abspath, err := filepath.Abs(from)
 	if err != nil {
 		return from
@@ -349,22 +317,15 @@ func realpath(mfDir, from string) string {
 		return filepath.Join(home, from[2:])
 	}

-	if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil {
+	if _, err := os.Stat(filepath.Join(rel, from)); err == nil {
 		// this is a file relative to the Modelfile
-		return filepath.Join(mfDir, from)
+		return filepath.Join(rel, from)
 	}

 	return abspath
 }

-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) error {
-	deleteMap := make(map[string]struct{})
-	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
-		for _, layer := range append(manifest.Layers, manifest.Config) {
-			deleteMap[layer.Digest] = struct{}{}
-		}
-	}
-
+func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
 	config := ConfigV2{
 		OS:           "linux",
 		Architecture: "amd64",
@@ -373,250 +334,181 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		},
 	}

-	var layers Layers
-	messages := []string{}
-	params := make(map[string][]string)
-	fromParams := make(map[string]any)
+	var messages []*api.Message
+	parameters := make(map[string]any)

+	var layers []*Layer
 	for _, c := range modelfile.Commands {
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

 		switch c.Name {
-		case "model":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
-				if err != nil {
-					return err
-				}
-
-				c.Args = blobPath
-			}
-
-			pathName := realpath(modelFileDir, c.Args)
-			ggufName, err := convertModel(name, pathName, fn)
-			if err != nil {
-				var pathErr *fs.PathError
-				switch {
-				case errors.Is(err, zip.ErrFormat):
-					// it's not a safetensor archive
-				case errors.As(err, &pathErr):
-					// it's not a file on disk, could be a model reference
-				default:
-					return err
-				}
-			}
-
-			if ggufName != "" {
-				pathName = ggufName
-				defer os.RemoveAll(ggufName)
-
-				if quantization != "" {
-					quantization = strings.ToUpper(quantization)
-					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
-					tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
-					if err != nil {
-						return err
-					}
-					defer os.RemoveAll(tempfile.Name())
-
-					if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
-						return err
-					}
-
-					if err := tempfile.Close(); err != nil {
-						return err
-					}
-
-					pathName = tempfile.Name()
-				}
-			}
-
-			bin, err := os.Open(pathName)
-			if err != nil {
-				// not a file on disk so must be a model reference
-				modelpath := ParseModelPath(c.Args)
-				manifest, _, err := GetManifest(modelpath)
-				switch {
-				case errors.Is(err, os.ErrNotExist):
-					fn(api.ProgressResponse{Status: "pulling model"})
-					if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
-						return err
-					}
-
-					manifest, _, err = GetManifest(modelpath)
-					if err != nil {
-						return err
-					}
-				case err != nil:
-					return err
-				}
-
-				fn(api.ProgressResponse{Status: "reading model metadata"})
-				fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
-				if err != nil {
-					return err
-				}
-
-				fromConfigFile, err := os.Open(fromConfigPath)
-				if err != nil {
-					return err
-				}
-				defer fromConfigFile.Close()
-
-				var fromConfig ConfigV2
-				if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
-					return err
-				}
-
-				// if the model is still not in gguf format, error out
-				if fromConfig.ModelFormat != "gguf" {
-					return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
-				}
-
-				config.SetModelFormat(fromConfig.ModelFormat)
-				config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
-				config.SetModelType(fromConfig.ModelType)
-				config.SetFileType(fromConfig.FileType)
-
-				for _, layer := range manifest.Layers {
-					deleteMap[layer.Digest] = struct{}{}
-					if layer.MediaType == "application/vnd.ollama.image.params" {
-						fromParamsPath, err := GetBlobsPath(layer.Digest)
-						if err != nil {
-							return err
-						}
-
-						fromParamsFile, err := os.Open(fromParamsPath)
-						if err != nil {
-							return err
-						}
-						defer fromParamsFile.Close()
-
-						if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
-							return err
-						}
-					}
-
-					layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
-					if err != nil {
-						return err
-					}
-
-					layers.Add(layer)
-				}
-
-				deleteMap[manifest.Config.Digest] = struct{}{}
-				continue
-			}
-			defer bin.Close()
-
-			var offset int64
-			for {
-				fn(api.ProgressResponse{Status: "creating model layer"})
-				if _, err := bin.Seek(offset, io.SeekStart); err != nil {
-					return err
-				}
-
-				ggml, size, err := llm.DecodeGGML(bin)
-				if errors.Is(err, io.EOF) {
-					break
-				} else if errors.Is(err, llm.ErrUnsupportedFormat) {
-					return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
-				} else if err != nil {
-					return err
-				}
-
-				config.SetModelFormat(ggml.Name())
-				config.SetModelFamily(ggml.KV().Architecture())
-				config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
-				config.SetFileType(ggml.KV().FileType())
-
-				mediatype := mediatype
-				if ggml.KV().Architecture() == "clip" {
-					mediatype = "application/vnd.ollama.image.projector"
-				}
-
-				sr := io.NewSectionReader(bin, offset, size)
-				layer, err := NewLayer(sr, mediatype)
-				if err != nil {
-					return err
-				}
-
-				layers.Add(layer)
-
-				offset += size
-			}
-		case "adapter":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
-				if err != nil {
-					return err
-				}
-
-				c.Args = blobPath
-			}
-
-			fn(api.ProgressResponse{Status: "creating adapter layer"})
-			bin, err := os.Open(realpath(modelFileDir, c.Args))
-			if err != nil {
-				return err
-			}
-			defer bin.Close()
-
-			_, size, err := llm.DecodeGGML(bin)
-			if err != nil {
-				return err
-			}
-
-			sr := io.NewSectionReader(bin, 0, size)
-			layer, err := NewLayer(sr, mediatype)
-			if err != nil {
-				return err
-			}
-
-			layers.Add(layer)
-		case "license":
-			fn(api.ProgressResponse{Status: "creating license layer"})
-
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
-			if err != nil {
-				return err
-			}
-
-			layers.Add(layer)
-		case "template", "system":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
-
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
-			if err != nil {
-				return err
-			}
-
-			layers.Replace(layer)
+		case "model", "adapter":
+			var baseLayers []*layerWithGGML
+			if name := model.ParseName(c.Args); name.IsValid() {
+				baseLayers, err = parseFromModel(ctx, name, fn)
+				if err != nil {
+					return err
+				}
+			} else if strings.HasPrefix(c.Args, "@") {
+				blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
+				if err != nil {
+					return err
+				}
+
+				blob, err := os.Open(blobpath)
+				if err != nil {
+					return err
+				}
+				defer blob.Close()
+
+				baseLayers, err = parseFromFile(ctx, blob, fn)
+				if err != nil {
+					return err
+				}
+			} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
+				defer file.Close()
+
+				baseLayers, err = parseFromFile(ctx, file, fn)
+				if err != nil {
+					return err
+				}
+			} else {
+				return fmt.Errorf("invalid model reference: %s", c.Args)
+			}
+
+			for _, baseLayer := range baseLayers {
+				if quantization != "" &&
+					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
+					baseLayer.GGML != nil &&
+					baseLayer.GGML.Name() == "gguf" {
+					ftype, err := llm.ParseFileType(quantization)
+					if err != nil {
+						return err
+					}
+
+					filetype := baseLayer.GGML.KV().FileType()
+					if !slices.Contains([]string{"F16", "F32"}, filetype) {
+						return errors.New("quantization is only supported for F16 and F32 models")
+					}
+
+					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", filetype, quantization)})
+
+					blob, err := GetBlobsPath(baseLayer.Digest)
+					if err != nil {
+						return err
+					}
+
+					temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
+					if err != nil {
+						return err
+					}
+					defer temp.Close()
+					defer os.Remove(temp.Name())
+
+					if err := llm.Quantize(blob, temp.Name(), ftype); err != nil {
+						return err
+					}
+
+					baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
+					if err != nil {
+						return err
+					}
+				}
+
+				if baseLayer.GGML != nil {
+					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
+					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
+					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
+					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType())
+
+					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
+				}
+
+				layers = append(layers, baseLayer.Layer)
+			}
+		case "license", "template", "system":
+			blob := strings.NewReader(c.Args)
+			layer, err := NewLayer(blob, mediatype)
+			if err != nil {
+				return err
+			}
+
+			if c.Name != "license" {
+				// replace
+				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+					return layer.MediaType == mediatype
+				})
+			}
+
+			layers = append(layers, layer)
 		case "message":
-			messages = append(messages, c.Args)
+			role, content, ok := strings.Cut(c.Args, ": ")
+			if !ok {
+				return fmt.Errorf("invalid message: %s", c.Args)
+			}
+
+			messages = append(messages, &api.Message{Role: role, Content: content})
 		default:
-			params[c.Name] = append(params[c.Name], c.Args)
+			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
+			if err != nil {
+				return err
+			}
+
+			for k, v := range ps {
+				if ks, ok := parameters[k].([]string); ok {
+					parameters[k] = append(ks, v.([]string)...)
+				} else if vs, ok := v.([]string); ok {
+					parameters[k] = vs
+				} else {
+					parameters[k] = v
+				}
+			}
 		}
 	}

+	var err2 error
+	layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.message":
+			// if there are new messages, remove the inherited ones
+			if len(messages) > 0 {
+				return true
+			}
+
+			return false
+		case "application/vnd.ollama.image.params":
+			// merge inherited parameters with new ones
+			r, err := layer.Open()
+			if err != nil {
+				err2 = err
+				return false
+			}
+			defer r.Close()
+
+			var ps map[string]any
+			if err := json.NewDecoder(r).Decode(&ps); err != nil {
+				err2 = err
+				return false
+			}
+
+			for k, v := range ps {
+				if _, ok := parameters[k]; !ok {
+					parameters[k] = v
+				}
+			}
+
+			return true
+		default:
+			return false
+		}
+	})
+
+	if err2 != nil {
+		return err2
+	}
+
 	if len(messages) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
-
-		msgs := make([]api.Message, 0)
-
-		for _, m := range messages {
-			// todo: handle images
-			msg := strings.SplitN(m, ": ", 2)
-			msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
-		}
-
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+		if err := json.NewEncoder(&b).Encode(messages); err != nil {
 			return err
 		}
@@ -625,39 +517,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 			return err
 		}

-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}

-	if len(params) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
-
-		formattedParams, err := api.FormatParams(params)
-		if err != nil {
-			return err
-		}
-
-		for k, v := range fromParams {
-			if _, ok := formattedParams[k]; !ok {
-				formattedParams[k] = v
-			}
-		}
-
+	if len(parameters) > 0 {
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
+		if err := json.NewEncoder(&b).Encode(parameters); err != nil {
 			return err
 		}

-		fn(api.ProgressResponse{Status: "creating config layer"})
 		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
 		if err != nil {
 			return err
 		}

-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}

-	digests := make([]string, len(layers.items))
-	for i, layer := range layers.items {
+	digests := make([]string, len(layers))
+	for i, layer := range layers {
 		digests[i] = layer.Digest
 	}
@@ -668,36 +546,37 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
 		return err
 	}

-	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
 	if err != nil {
 		return err
 	}

-	delete(deleteMap, configLayer.Digest)
-
-	for _, layer := range append(layers.items, configLayer) {
-		committed, err := layer.Commit()
-		if err != nil {
-			return err
-		}
-
-		status := "writing layer"
-		if !committed {
-			status = "using already created layer"
-		}
-
-		fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
-
-		delete(deleteMap, layer.Digest)
+	for _, layer := range append(layers, layer) {
+		if layer.status != "" {
+			fn(api.ProgressResponse{Status: layer.status})
+		}
+	}
+
+	unref := make(map[string]struct{})
+	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
+		for _, layer := range manifest.Layers {
+			if !slices.Contains(digests, layer.Digest) {
+				unref[layer.Digest] = struct{}{}
+			}
+		}
+
+		if manifest.Config.Digest != layer.Digest {
+			unref[manifest.Config.Digest] = struct{}{}
+		}
 	}

 	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := WriteManifest(name, configLayer, layers.items); err != nil {
+	if err := WriteManifest(name, layer, layers); err != nil {
 		return err
 	}

 	if !envconfig.NoPrune {
-		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
+		if err := deleteUnusedLayers(nil, unref, false); err != nil {
 			return err
 		}
 	}
@@ -706,74 +585,6 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m

 	return nil
 }
-func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
-	r, err := zip.OpenReader(path)
-	if err != nil {
-		return "", err
-	}
-	defer r.Close()
-
-	tempDir, err := os.MkdirTemp("", "ollama-convert")
-	if err != nil {
-		return "", err
-	}
-	defer os.RemoveAll(tempDir)
-
-	fn(api.ProgressResponse{Status: "unpacking model metadata"})
-	for _, f := range r.File {
-		fpath := filepath.Join(tempDir, f.Name)
-		outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
-		if err != nil {
-			return "", err
-		}
-
-		rc, err := f.Open()
-		if err != nil {
-			return "", err
-		}
-
-		_, err = io.Copy(outFile, rc)
-		if err != nil {
-			return "", err
-		}
-
-		outFile.Close()
-		rc.Close()
-	}
-
-	mf, err := convert.GetModelFormat(tempDir)
-	if err != nil {
-		return "", err
-	}
-
-	params, err := mf.GetParams(tempDir)
-	if err != nil {
-		return "", err
-	}
-
-	mArch, err := mf.GetModelArch(name, tempDir, params)
-	if err != nil {
-		return "", err
-	}
-
-	fn(api.ProgressResponse{Status: "processing tensors"})
-	if err := mArch.GetTensors(); err != nil {
-		return "", err
-	}
-
-	if err := mArch.LoadVocab(); err != nil {
-		return "", err
-	}
-
-	fn(api.ProgressResponse{Status: "converting model"})
-	path, err = mArch.WriteGGUF()
-	if err != nil {
-		return "", err
-	}
-
-	return path, nil
-}
-
 func CopyModel(src, dst model.Name) error {
 	if !dst.IsFullyQualified() {
 		return model.Unqualified(dst)

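Note: the ConfigV2 setter methods are replaced by cmp.Or (added in Go 1.22), which returns the first non-zero argument, so "keep whatever was set first" collapses to one line. A small illustration of the semantics:

	format := ""
	format = cmp.Or(format, "gguf") // format was empty, so it becomes "gguf"
	format = cmp.Or(format, "ggla") // no-op: format is already non-zero
	fmt.Println(format)             // gguf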
@@ -5,39 +5,14 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"strings"
-
-	"golang.org/x/exp/slices"
 )

-type Layers struct {
-	items []*Layer
-}
-
-func (ls *Layers) Add(layer *Layer) {
-	if layer.Size > 0 {
-		ls.items = append(ls.items, layer)
-	}
-}
-
-func (ls *Layers) Replace(layer *Layer) {
-	if layer.Size > 0 {
-		mediatype := layer.MediaType
-		layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
-			return l.MediaType == mediatype
-		})
-
-		ls.items = append(layers, layer)
-	}
-}
-
 type Layer struct {
 	MediaType string `json:"mediaType"`
 	Digest    string `json:"digest"`
 	Size      int64  `json:"size"`
 	From      string `json:"from,omitempty"`
-
-	tempFileName string
+	status    string
 }

 func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
@@ -46,14 +21,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}

-	const delimiter = "-"
-
-	pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
-	temp, err := os.CreateTemp(blobs, pattern)
+	temp, err := os.CreateTemp(blobs, "sha256-")
 	if err != nil {
 		return nil, err
 	}
 	defer temp.Close()
+	defer os.Remove(temp.Name())

 	sha256sum := sha256.New()
 	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
@@ -61,11 +34,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}

+	if err := temp.Close(); err != nil {
+		return nil, err
+	}
+
+	digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
+	blob, err := GetBlobsPath(digest)
+	if err != nil {
+		return nil, err
+	}
+
+	status := "using existing layer"
+	if _, err := os.Stat(blob); err != nil {
+		status = "creating new layer"
+		if err := os.Rename(temp.Name(), blob); err != nil {
+			return nil, err
+		}
+	}
+
 	return &Layer{
-		MediaType:    mediatype,
-		Digest:       fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
-		Size:         n,
-		tempFileName: temp.Name(),
+		MediaType: mediatype,
+		Digest:    digest,
+		Size:      n,
+		status:    fmt.Sprintf("%s %s", status, digest),
 	}, nil
 }
@@ -85,21 +76,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
 		Digest:    digest,
 		Size:      fi.Size(),
 		From:      from,
+		status:    fmt.Sprintf("using existing layer %s", digest),
 	}, nil
 }

-func (l *Layer) Commit() (bool, error) {
-	// always remove temp
-	defer os.Remove(l.tempFileName)
-
-	blob, err := GetBlobsPath(l.Digest)
-	if err != nil {
-		return false, err
-	}
-
-	if _, err := os.Stat(blob); err != nil {
-		return true, os.Rename(l.tempFileName, blob)
-	}
-
-	return false, nil
+func (l *Layer) Open() (io.ReadCloser, error) {
+	blob, err := GetBlobsPath(l.Digest)
+	if err != nil {
+		return nil, err
+	}
+
+	return os.Open(blob)
 }

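Note: layers are now content-addressed at creation time. NewLayer hashes the stream into a temp file and renames it into the blob store when the digest is new, so the separate Commit step disappears and reads go through Open. A sketch from inside the server package:

	// By the time NewLayer returns, the blob already sits at its
	// sha256-addressed path (or was already there from a previous run).
	layer, err := NewLayer(strings.NewReader("{{ .Prompt }}"), "application/vnd.ollama.image.template")
	if err != nil {
		return err
	}

	// Reading a layer back is just opening its blob by digest.
	r, err := layer.Open()
	if err != nil {
		return err
	}
	defer r.Close()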
server/model.go (new file, +261)

@@ -0,0 +1,261 @@
package server

import (
"archive/zip"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/types/model"
)

type layerWithGGML struct {
*Layer
*llm.GGML
}

func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
modelpath := ParseModelPath(name.String())
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
modelpath = ParseModelPath(name.String())
manifest, _, err = GetManifest(modelpath)
if err != nil {
return nil, err
}
case err != nil:
return nil, err
}
for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
if err != nil {
return nil, err
}
switch layer.MediaType {
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
blob, err := os.Open(blobpath)
if err != nil {
return nil, err
}
defer blob.Close()
ggml, _, err := llm.DecodeGGML(blob)
if err != nil {
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
default:
layers = append(layers, &layerWithGGML{layer, nil})
}
}
return layers, nil
}

func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
stat, err := file.Stat()
if err != nil {
return nil, err
}
r, err := zip.NewReader(file, stat.Size())
if err != nil {
return nil, err
}
tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
if err != nil {
return nil, err
}
defer os.RemoveAll(tempdir)
fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File {
// TODO(mxyng): this should not write out all files to disk
outfile, err := os.Create(filepath.Join(tempdir, f.Name))
if err != nil {
return nil, err
}
defer outfile.Close()
infile, err := f.Open()
if err != nil {
return nil, err
}
defer infile.Close()
if _, err = io.Copy(outfile, infile); err != nil {
return nil, err
}
if err := outfile.Close(); err != nil {
return nil, err
}
if err := infile.Close(); err != nil {
return nil, err
}
}
mf, err := convert.GetModelFormat(tempdir)
if err != nil {
return nil, err
}
params, err := mf.GetParams(tempdir)
if err != nil {
return nil, err
}
mArch, err := mf.GetModelArch("", tempdir, params)
if err != nil {
return nil, err
}
fn(api.ProgressResponse{Status: "processing tensors"})
if err := mArch.GetTensors(); err != nil {
return nil, err
}
if err := mArch.LoadVocab(); err != nil {
return nil, err
}
fn(api.ProgressResponse{Status: "converting model"})
// TODO(mxyng): this should write directly into a layer
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
temp, err := os.CreateTemp(tempdir, "fp16")
if err != nil {
return nil, err
}
defer temp.Close()
defer os.Remove(temp.Name())
if err = mArch.WriteGGUF(temp); err != nil {
return nil, err
}
if _, err := temp.Seek(0, io.SeekStart); err != nil {
return nil, err
}
layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
if err != nil {
return nil, fmt.Errorf("aaa: %w", err)
}
blobpath, err := GetBlobsPath(layer.Digest)
if err != nil {
return nil, err
}
bin, err := os.Open(blobpath)
if err != nil {
return nil, err
}
defer bin.Close()
ggml, _, err := llm.DecodeGGML(bin)
if err != nil {
return nil, err
}
layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "")
if err != nil {
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
return layers, nil
}

func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
sr := io.NewSectionReader(file, 0, 512)
contentType, err := detectContentType(sr)
if err != nil {
return nil, err
}
switch contentType {
case "gguf", "ggla":
// noop
case "application/zip":
return parseFromZipFile(ctx, file, fn)
default:
return nil, fmt.Errorf("unsupported content type: %s", contentType)
}
stat, err := file.Stat()
if err != nil {
return nil, err
}
var offset int64
for offset < stat.Size() {
ggml, n, err := llm.DecodeGGML(file)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, err
}
mediatype := "application/vnd.ollama.image.model"
if ggml.Name() == "ggla" {
mediatype = "application/vnd.ollama.image.adapter"
} else if ggml.KV().Architecture() == "clip" {
mediatype = "application/vnd.ollama.image.projector"
}
layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
if err != nil {
return nil, err
}
layers = append(layers, &layerWithGGML{layer, ggml})
offset = n
}
return layers, nil
}

func detectContentType(r io.Reader) (string, error) {
var b bytes.Buffer
if _, err := io.Copy(&b, r); err != nil {
return "", err
}
if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" {
return contentType, nil
}
if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" {
return contentType, nil
}
return "unknown", nil
}

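Note: detectContentType tries the GGML magics first and only then falls back to net/http content sniffing, which is what lets FROM point at either a raw GGUF/GGLA blob or a zipped safetensors export. A sketch of both paths:

	ggufType, _ := detectContentType(bytes.NewReader([]byte("GGUF\x03\x00\x00\x00")))
	fmt.Println(ggufType) // "gguf", matched by llm.DetectGGMLType

	zipType, _ := detectContentType(bytes.NewReader([]byte("PK\x03\x04")))
	fmt.Println(zipType) // "application/zip", via http.DetectContentType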
@@ -560,7 +560,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 	ctx, cancel := context.WithCancel(c.Request.Context())
 	defer cancel()

-	if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), req.Quantization, modelfile, fn); err != nil {
+	if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil {
 		ch <- gin.H{"error": err.Error()}
 	}
 }()
@@ -852,11 +852,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		return
 	}

-	if _, err := layer.Commit(); err != nil {
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
-
 	c.Status(http.StatusCreated)
 }

@@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) {
 			Method: http.MethodPost,
 			Path:   "/api/create",
 			Setup: func(t *testing.T, req *http.Request) {
-				f, err := os.CreateTemp(t.TempDir(), "ollama-model")
-				assert.Nil(t, err)
-				defer f.Close()
+				fname := createTestFile(t, "ollama-model")

 				stream := false
 				createReq := api.CreateRequest{
 					Name:      "t-bone",
-					Modelfile: fmt.Sprintf("FROM %s", f.Name()),
+					Modelfile: fmt.Sprintf("FROM %s", fname),
 					Stream:    &stream,
 				}
 				jsonData, err := json.Marshal(createReq)
@@ -216,27 +214,25 @@ func Test_Routes(t *testing.T) {
 	httpSrv := httptest.NewServer(router)
 	t.Cleanup(httpSrv.Close)

-	workDir, err := os.MkdirTemp("", "ollama-test")
-	assert.Nil(t, err)
-	defer os.RemoveAll(workDir)
-	os.Setenv("OLLAMA_MODELS", workDir)
+	t.Setenv("OLLAMA_MODELS", t.TempDir())

 	for _, tc := range testCases {
-		t.Logf("Running Test: [%s]", tc.Name)
-		u := httpSrv.URL + tc.Path
-		req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
-		assert.Nil(t, err)
-
-		if tc.Setup != nil {
-			tc.Setup(t, req)
-		}
-
-		resp, err := httpSrv.Client().Do(req)
-		assert.Nil(t, err)
-		defer resp.Body.Close()
-
-		if tc.Expected != nil {
-			tc.Expected(t, resp)
-		}
+		t.Run(tc.Name, func(t *testing.T) {
+			u := httpSrv.URL + tc.Path
+			req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
+			assert.Nil(t, err)
+
+			if tc.Setup != nil {
+				tc.Setup(t, req)
+			}
+
+			resp, err := httpSrv.Client().Do(req)
+			assert.Nil(t, err)
+			defer resp.Body.Close()
+
+			if tc.Expected != nil {
+				tc.Expected(t, resp)
+			}
+		})
 	}
 }