refactor create model

This commit is contained in:
Michael Yang 2023-11-14 12:30:34 -08:00
parent f61f340279
commit b0d14ed51c
2 changed files with 162 additions and 184 deletions

View file

@ -248,200 +248,172 @@ func filenameWithPath(path, f string) (string, error) {
return f, nil return f, nil
} }
func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { func realpath(p string) string {
mp := ParseModelPath(name) abspath, err := filepath.Abs(p)
var manifest *ManifestV2
var err error
var noprune string
// build deleteMap to prune unused layers
deleteMap := make(map[string]bool)
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if manifest != nil {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = true
}
deleteMap[manifest.Config.Digest] = true
}
}
mf, err := os.Open(path)
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't open modelfile '%s'", path)}) return p
return fmt.Errorf("failed to open file: %w", err)
} }
defer mf.Close()
fn(api.ProgressResponse{Status: "parsing modelfile"}) home, err := os.UserHomeDir()
commands, err := parser.Parse(mf)
if err != nil { if err != nil {
return err return abspath
} }
if p == "~" {
return home
} else if strings.HasPrefix(p, "~/") {
return filepath.Join(home, p[2:])
}
return abspath
}
func CreateModel(ctx context.Context, name string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
config := ConfigV2{ config := ConfigV2{
Architecture: "amd64",
OS: "linux", OS: "linux",
Architecture: "amd64",
} }
deleteMap := make(map[string]struct{})
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
var sourceParams map[string]any fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
log.Printf("[%s] - %s\n", c.Name, c.Args) log.Printf("[%s] - %s", c.Name, c.Args)
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
case "model": case "model":
fn(api.ProgressResponse{Status: "looking for model"}) bin, err := os.Open(realpath(c.Args))
mp := ParseModelPath(c.Args)
mf, _, err := GetManifest(mp)
if err != nil { if err != nil {
modelFile, err := filenameWithPath(path, c.Args) // not a file on disk so must be a model reference
if err != nil { modelpath := ParseModelPath(c.Args)
manifest, _, err := GetManifest(modelpath)
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
manifest, _, err = GetManifest(modelpath)
if err != nil {
return err
}
case err != nil:
return err return err
} }
if _, err := os.Stat(modelFile); err != nil {
// the model file does not exist, try pulling it
if errors.Is(err, os.ErrNotExist) {
fn(api.ProgressResponse{Status: "pulling model file"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err
}
mf, _, err = GetManifest(mp)
if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err)
}
} else {
return err
}
} else {
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(modelFile)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
ggml, err := llm.DecodeGGML(file)
if err != nil {
return err
}
config.ModelFormat = ggml.Name()
config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
// reset the file
file.Seek(0, io.SeekStart)
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.model"
layers = append(layers, l)
}
}
if mf != nil {
fn(api.ProgressResponse{Status: "reading model metadata"}) fn(api.ProgressResponse{Status: "reading model metadata"})
sourceBlobPath, err := GetBlobsPath(mf.Config.Digest) fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil { if err != nil {
return err return err
} }
sourceBlob, err := os.Open(sourceBlobPath) fromConfigFile, err := os.Open(fromConfigPath)
if err != nil { if err != nil {
return err return err
} }
defer sourceBlob.Close() defer fromConfigFile.Close()
var source ConfigV2 var fromConfig ConfigV2
if err := json.NewDecoder(sourceBlob).Decode(&source); err != nil { if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
return err return err
} }
// copy the model metadata config.ModelFormat = fromConfig.ModelFormat
config.ModelFamily = source.ModelFamily config.ModelFamily = fromConfig.ModelFamily
config.ModelType = source.ModelType config.ModelType = fromConfig.ModelType
config.ModelFormat = source.ModelFormat config.FileType = fromConfig.FileType
config.FileType = source.FileType
for _, l := range mf.Layers { for _, layer := range manifest.Layers {
if l.MediaType == "application/vnd.ollama.image.params" { deleteMap[layer.Digest] = struct{}{}
sourceParamsBlobPath, err := GetBlobsPath(l.Digest) if layer.MediaType == "application/vnd.ollama.image.params" {
fromParamsPath, err := GetBlobsPath(layer.Digest)
if err != nil { if err != nil {
return err return err
} }
sourceParamsBlob, err := os.Open(sourceParamsBlobPath) fromParamsFile, err := os.Open(fromParamsPath)
if err != nil { if err != nil {
return err return err
} }
defer sourceParamsBlob.Close() defer fromParamsFile.Close()
if err := json.NewDecoder(sourceParamsBlob).Decode(&sourceParams); err != nil { if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
return err return err
} }
} }
newLayer, err := GetLayerWithBufferFromLayer(l) layer, err := GetLayerWithBufferFromLayer(layer)
if err != nil { if err != nil {
return err return err
} }
newLayer.From = mp.GetShortTagname()
layers = append(layers, newLayer)
}
}
case "adapter":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
fp, err := filenameWithPath(path, c.Args) layer.From = modelpath.GetShortTagname()
layers = append(layers, layer)
}
deleteMap[manifest.Config.Digest] = struct{}{}
continue
}
defer bin.Close()
fn(api.ProgressResponse{Status: "creating model layer"})
ggml, err := llm.DecodeGGML(bin)
if err != nil { if err != nil {
return err return err
} }
// create a model from this specified file config.ModelFormat = ggml.Name()
fn(api.ProgressResponse{Status: "creating model layer"}) config.ModelFamily = ggml.ModelFamily()
config.ModelType = ggml.ModelType()
config.FileType = ggml.FileType()
file, err := os.Open(fp) bin.Seek(0, io.SeekStart)
layer, err := CreateLayer(bin)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file: %v", err) return err
} }
defer file.Close()
l, err := CreateLayer(file) layer.MediaType = mediatype
layers = append(layers, layer)
case "adapter":
fn(api.ProgressResponse{Status: "creating adapter layer"})
bin, err := os.Open(realpath(c.Args))
if err != nil { if err != nil {
return fmt.Errorf("failed to create layer: %v", err) return err
} }
l.MediaType = "application/vnd.ollama.image.adapter" defer bin.Close()
layers = append(layers, l)
case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layer, err := CreateLayer(strings.NewReader(c.Args)) layer, err := CreateLayer(bin)
if err != nil { if err != nil {
return err return err
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
case "template", "system", "prompt": case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: "creating license layer"})
// remove the layer if one exists layer, err := CreateLayer(strings.NewReader(c.Args))
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) if err != nil {
layers = removeLayerFromLayers(layers, mediaType) return err
}
if layer.Size > 0 {
layer.MediaType = mediatype
layers = append(layers, layer)
}
case "template", "system":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
// remove duplicates layers
layers = removeLayerFromLayers(layers, mediatype)
layer, err := CreateLayer(strings.NewReader(c.Args)) layer, err := CreateLayer(strings.NewReader(c.Args))
if err != nil { if err != nil {
@ -449,48 +421,47 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
} }
if layer.Size > 0 { if layer.Size > 0 {
layer.MediaType = mediaType layer.MediaType = mediatype
layers = append(layers, layer) layers = append(layers, layer)
} }
default: default:
// runtime parameters, build a list of args for each parameter to allow multiple values to be specified (ex: multiple stop sequences)
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
// Create a single layer for the parameters
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameter layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
formattedParams, err := formatParams(params) formattedParams, err := formatParams(params)
if err != nil { if err != nil {
return fmt.Errorf("couldn't create params json: %v", err) return err
} }
for k, v := range sourceParams { for k, v := range fromParams {
if _, ok := formattedParams[k]; !ok { if _, ok := formattedParams[k]; !ok {
formattedParams[k] = v formattedParams[k] = v
} }
} }
if config.ModelType == "65B" { if config.ModelType == "65B" {
if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 { if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B" config.ModelType = "70B"
} }
} }
bts, err := json.Marshal(formattedParams) var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
return err
}
fn(api.ProgressResponse{Status: "creating config layer"})
layer, err := CreateLayer(bytes.NewReader(b.Bytes()))
if err != nil { if err != nil {
return err return err
} }
l, err := CreateLayer(bytes.NewReader(bts)) layer.MediaType = "application/vnd.ollama.image.params"
if err != nil { layers = append(layers, layer)
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.params"
layers = append(layers, l)
} }
digests, err := getLayerDigests(layers) digests, err := getLayerDigests(layers)
@ -498,36 +469,31 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
var manifestLayers []*Layer configLayer, err := createConfigLayer(config, digests)
for _, l := range layers {
manifestLayers = append(manifestLayers, &l.Layer)
delete(deleteMap, l.Layer.Digest)
}
// Create a layer for the config object
fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(config, digests)
if err != nil { if err != nil {
return err return err
} }
layers = append(layers, cfg)
delete(deleteMap, cfg.Layer.Digest) layers = append(layers, configLayer)
delete(deleteMap, configLayer.Digest)
if err := SaveLayers(layers, fn, false); err != nil { if err := SaveLayers(layers, fn, false); err != nil {
return err return err
} }
// Create the manifest var contentLayers []*Layer
for _, layer := range layers {
contentLayers = append(contentLayers, &layer.Layer)
delete(deleteMap, layer.Digest)
}
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
err = CreateManifest(name, cfg, manifestLayers) if err := CreateManifest(name, configLayer, contentLayers); err != nil {
if err != nil {
return err return err
} }
if noprune == "" { if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
fn(api.ProgressResponse{Status: "removing any unused layers"}) if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
err = deleteUnusedLayers(nil, deleteMap, false)
if err != nil {
return err return err
} }
} }
@ -739,7 +705,7 @@ func CopyModel(src, dest string) error {
return nil return nil
} }
func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dryRun bool) error { func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error {
fp, err := GetManifestPath() fp, err := GetManifestPath()
if err != nil { if err != nil {
return err return err
@ -779,21 +745,19 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
} }
// only delete the files which are still in the deleteMap // only delete the files which are still in the deleteMap
for k, v := range deleteMap { for k := range deleteMap {
if v { fp, err := GetBlobsPath(k)
fp, err := GetBlobsPath(k) if err != nil {
if err != nil { log.Printf("couldn't get file path for '%s': %v", k, err)
log.Printf("couldn't get file path for '%s': %v", k, err) continue
}
if !dryRun {
if err := os.Remove(fp); err != nil {
log.Printf("couldn't remove file '%s': %v", fp, err)
continue continue
} }
if !dryRun { } else {
if err := os.Remove(fp); err != nil { log.Printf("wanted to remove: %s", fp)
log.Printf("couldn't remove file '%s': %v", fp, err)
continue
}
} else {
log.Printf("wanted to remove: %s", fp)
}
} }
} }
@ -801,7 +765,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]bool, dry
} }
func PruneLayers() error { func PruneLayers() error {
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("") p, err := GetBlobsPath("")
if err != nil { if err != nil {
return err return err
@ -818,7 +782,7 @@ func PruneLayers() error {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
name = strings.ReplaceAll(name, "-", ":") name = strings.ReplaceAll(name, "-", ":")
} }
deleteMap[name] = true deleteMap[name] = struct{}{}
} }
log.Printf("total blobs: %d", len(deleteMap)) log.Printf("total blobs: %d", len(deleteMap))
@ -873,11 +837,11 @@ func DeleteModel(name string) error {
return err return err
} }
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
deleteMap[layer.Digest] = true deleteMap[layer.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
err = deleteUnusedLayers(&mp, deleteMap, false) err = deleteUnusedLayers(&mp, deleteMap, false)
if err != nil { if err != nil {
@ -1013,7 +977,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
var noprune string var noprune string
// build deleteMap to prune unused layers // build deleteMap to prune unused layers
deleteMap := make(map[string]bool) deleteMap := make(map[string]struct{})
if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
manifest, _, err = GetManifest(mp) manifest, _, err = GetManifest(mp)
@ -1023,9 +987,9 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
if manifest != nil { if manifest != nil {
for _, l := range manifest.Layers { for _, l := range manifest.Layers {
deleteMap[l.Digest] = true deleteMap[l.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = true deleteMap[manifest.Config.Digest] = struct{}{}
} }
} }

View file

@ -26,6 +26,7 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
@ -414,6 +415,19 @@ func CreateModelHandler(c *gin.Context) {
return return
} }
modelfile, err := os.Open(req.Path)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer modelfile.Close()
commands, err := parser.Parse(modelfile)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
@ -424,7 +438,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { if err := CreateModel(ctx, req.Name, commands, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()