refactor create model

This commit is contained in:
Michael Yang 2023-11-14 12:30:34 -08:00
parent f61f340279
commit b0d14ed51c
2 changed files with 162 additions and 184 deletions

View file

@ -248,88 +248,122 @@ func filenameWithPath(path, f string) (string, error) {
return f, nil return f, nil
} }
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { func realpath(p string) string {
mp := ParseModelPath(name) abspath, err := filepath.Abs(p)
var manifest *ManifestV2
var err error
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]bool)
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = true
}
deleteMap[manifest.Config.Digest] = true
}
}
mf, err := os.Open(path)
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) return p
return fmt.Errorf("failed to open file: %w", err)
} }
defer mf.Close()
fn(api.ProgressResponse{Status: "parsing modelfile"}) home, err := os.UserHomeDir()
commands, err := parser.Parse(mf)
if err != nil { if err != nil {
return err return abspath
} }
if p == "~" {
return home
} else if strings.HasPrefix(p, "~/") {
return filepath.Join(home, p[2:])
}
return abspath
}
func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
config := ConfigV2{ config := ConfigV2{
Architecture: "amd64",
OS: "linux", OS: "linux",
Architecture: "amd64",
} }
deleteMap := make(map[string]struct{})
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
var sourceParams map[string]any fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s", c.Name, c.Args)
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) bin, err := os.Open(realpath(c.Args))
mp := ParseModelPath(c.Args)
mf, _, err := GetManifest(mp)
if err != nil { if err != nil {
modelFile, err := filenameWithPath(path, c.Args) // not a file on disk so must be a model reference
if err != nil { modelpath := ParseModelPath(c.Args)
return err manifest, _, err := GetManifest(modelpath)
} switch {
if _, err := os.Stat(modelFile); err != nil { case errors.Is(err, os.ErrNotExist):
// the model file does not exist, try pulling it fn(api.ProgressResponse{Status: "pulling model"})
if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err return err
} }
mf, _, err = GetManifest(mp)
manifest, _, err = GetManifest(modelpath)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
} else {
return err return err
} }
} else { case err != nil:
// create a model from this specified file return err
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
} }
defer file.Close()
ggml, err := llm.DecodeGGML(file) fn(api.ProgressResponse{Status: "reading model metadata"})
fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return err
}
fromConfigFile, err := os.Open(fromConfigPath)
if err != nil {
return err
}
defer fromConfigFile.Close()
var fromConfig ConfigV2
if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
return err
}
config.ModelFormat = fromConfig.ModelFormat
config.ModelFamily = fromConfig.ModelFamily
config.ModelType = fromConfig.ModelType
config.FileType = fromConfig.FileType
for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = struct{}{}
if layer.MediaType == "application/vnd.ollama.image.params" {
fromParamsPath, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
fromParamsFile, err := os.Open(fromParamsPath)
if err != nil {
return err
}
defer fromParamsFile.Close()
if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
return err
}
}
layer, err := GetLayerWithBufferFromLayer(layer)
if err != nil {
return err
}
layer.From = modelpath.GetShortTagname()
layers = append(layers, layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
continue
}
defer bin.Close()
fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil { if err != nil {
return err return err
} }
@ -339,109 +373,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
config.ModelType = ggml.ModelType() config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType() config.FileType = ggml.FileType()
// reset the file bin.Seek(0, io.SeekStart)
file.Seek(0, io.SeekStart) layer, err := CreateLayer(bin)
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.model"
layers = append(layers, l)
}
}
if mf != nil {
fn(api.ProgressResponse{Status: "reading model metadata"})
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest)
if err != nil { if err != nil {
return err return err
} }
sourceBlob, err := os.Open(sourceBlobPath) layer.MediaType = mediatype
if err != nil { layers = append(layers, layer)
return err
}
defer sourceBlob.Close()
var source ConfigV2
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil {
return err
}
// copy the model metadata
config.ModelFamily = source.ModelFamily
config.ModelType = source.ModelType
config.ModelFormat = source.ModelFormat
config.FileType = source.FileType
for _, l := range mf.Layers {
if l.MediaType == "application/vnd.ollama.image.params" {
sourceParamsBlobPath, err := GetBlobsPath(l.Digest)
if err != nil {
return err
}
sourceParamsBlob, err := os.Open(sourceParamsBlobPath)
if err != nil {
return err
}
defer sourceParamsBlob.Close()
if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil {
return err
}
}
newLayer, err := GetLayerWithBufferFromLayer(l)
if err != nil {
return err
}
newLayer.From = mp.GetShortTagname()
layers = append(layers, newLayer)
}
}
case "adapter": case "adapter":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(c.Args))
fp, err := filenameWithPath(path, c.Args)
if err != nil { if err != nil {
return err return err
} }
defer bin.Close()
// create a model from this specified file layer, err := CreateLayer(bin)
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(fp)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.adapter"
layers = append(layers, l)
case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
return err return err
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
case "template", "system", "prompt": case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: "creating license layer"})
// remove the layer if one exists layer, err := CreateLayer(strings.NewReader(c.Args))
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) if err != nil {
layers = removeLayerFromLayers(layers, mediaType) return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
// remove duplicates layers
layers = removeLayerFromLayers(layers, mediatype)
layer, err := CreateLayer(strings.NewReader(c.Args)) layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
@ -449,48 +421,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
default: default:
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
// Create a single layer for the parameters
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
formattedParams, err := formatParams(params) formattedParams, err := formatParams(params)
if err != nil { if err != nil {
return fmt.Errorf("couldn't create params json: %v", err) return err
} }
for k, v := range sourceParams { for k, v := range fromParams {
if _, ok := formattedParams[k]; !ok { if _, ok := formattedParams[k]; !ok {
formattedParams[k] = v formattedParams[k] = v
} }
} }
if config.ModelType == "65B" { if config.ModelType == "65B" {
if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B" config.ModelType = "70B"
} }
} }
bts, err := json.Marshal(formattedParams) var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
return err
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
if err != nil { if err != nil {
return err return err
} }
l, err := CreateLayer(bytes.NewReader(bts)) layer.MediaType = "application/vnd.ollama.image.params"
if err != nil { layers = append(layers, layer)
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l)
} }
digests, err := getLayerDigests(layers) digests, err := getLayerDigests(layers)
@ -498,36 +469,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
var manifestLayers []*Layer configLayer, err := createConfigLayer(config, digests)
for _, l := range layers {
manifestLayers = append(manifestLayers, &l.Layer)
delete(deleteMap, l.Layer.Digest)
}
// Create a layer for the config object
fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(config, digests)
if err != nil { if err != nil {
return err return err
} }
layers = append(layers, cfg)
delete(deleteMap, cfg.Layer.Digest) layers = append(layers, configLayer)
delete(deleteMap, configLayer.Digest)
if err := SaveLayers(layers, fn, false); err != nil { if err := SaveLayers(layers, fn, false); err != nil {
return err return err
} }
// Create the manifest var contentLayers []*Layer
for _, layer := range layers {
contentLayers = append(contentLayers, &layer.Layer)
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
err = CreateManifest(name, cfg, manifestLayers) if err := CreateManifest(name, configLayer, contentLayers); err != nil {
if err != nil {
return err return err
} }
if noprune == "" { if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
fn(api.ProgressResponse{Status: "removing any unused layers"}) if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
err = deleteUnusedLayers(nil, deleteMap, false)
if err != nil {
return err return err
} }
} }
@ -739,7 +705,7 @@ func CopyModel(src, dest string) error {
return nil return nil
} }
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error { func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
fp, err := GetManifestPath() fp, err := GetManifestPath()
if err != nil { if err != nil {
return err return err
@ -779,8 +745,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
} }
// only delete the files which are still in the deleteMap // only delete the files which are still in the deleteMap
for k, v := range deleteMap { for k := range deleteMap {
if v {
fp, err := GetBlobsPath(k) fp, err := GetBlobsPath(k)
if err != nil { if err != nil {
log.Printf("couldn't get file path for '%s': %v", k, err) log.Printf("couldn't get file path for '%s': %v", k, err)
@ -795,13 +760,12 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
log.Printf("wanted to remove: %s", fp) log.Printf("wanted to remove: %s", fp)
} }
} }
}
return nil return nil
} }
func PruneLayers() error { func PruneLayers() error {
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("") p, err := GetBlobsPath("")
if err != nil { if err != nil {
return err return err
@ -818,7 +782,7 @@ func PruneLayers() error {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
name = strings.ReplaceAll(name, "-", ":") name = strings.ReplaceAll(name, "-", ":")
} }
deleteMap[name] = true deleteMap[name] = struct{}{}
} }
log.Printf("total blobs: %d", len(deleteMap)) log.Printf("total blobs: %d", len(deleteMap))
@ -873,11 +837,11 @@ func DeleteModel(name string) error {
return err return err
} }
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = true deleteMap[layer.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap, false) err = deleteUnusedLayers(&mp, deleteMap, false)
if err != nil { if err != nil {
@ -1013,7 +977,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var noprune string var noprune string
// build deleteMap to prune unused layers // build deleteMap to prune unused layers
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp) manifest, _, err = GetManifest(mp)
@ -1023,9 +987,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
if manifest != nil { if manifest != nil {
for _, l := range manifest.Layers { for _, l := range manifest.Layers {
deleteMap[l.Digest] = true deleteMap[l.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
} }
} }

View file

@ -26,6 +26,7 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
@ -414,6 +415,19 @@ func CreateModelHandler(c *gin.Context) {
return return
} }
modelfile, err := os.Open(req.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer modelfile.Close()
commands, err := parser.Parse(modelfile)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
@ -424,7 +438,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { if err := CreateModel(ctx, req.Name, commands, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()