Merge pull request #3833 from ollama/mxyng/fix-from

fix: from blob
This commit is contained in:
Michael Yang 2024-04-24 15:13:47 -07:00 committed by GitHub
commit 2010cbc5fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -17,6 +17,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
@ -53,8 +54,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
bars := make(map[string]*progress.Bar)
modelfile, err := os.ReadFile(filename) modelfile, err := os.ReadFile(filename)
if err != nil { if err != nil {
return err return err
@ -95,95 +94,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
// TODO make this work w/ adapters
if fi.IsDir() { if fi.IsDir() {
tf, err := os.CreateTemp("", "ollama-tf") // this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
if err != nil { if err != nil {
return err return err
} }
defer os.RemoveAll(tf.Name()) defer os.RemoveAll(tempfile)
zf := zip.NewWriter(tf) path = tempfile
files := []string{}
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
if err != nil {
return err
} else if len(tfiles) == 0 {
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return err
}
}
files = append(files, tfiles...)
if len(files) == 0 {
return fmt.Errorf("no models were found in '%s'", path)
}
// add the safetensor/torch config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "params.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))
for _, fn := range files {
f, err := os.Open(fn)
// just skip whatever files aren't there
if os.IsNotExist(err) {
if strings.HasSuffix(fn, "tokenizer.model") {
// try the parent dir before giving up
parentDir := filepath.Dir(path)
newFn := filepath.Join(parentDir, "tokenizer.model")
f, err = os.Open(newFn)
if os.IsNotExist(err) {
continue
} else if err != nil {
return err
}
} else {
continue
}
} else if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
h, err := zip.FileInfoHeader(fi)
if err != nil {
return err
}
h.Name = filepath.Base(fn)
h.Method = zip.Store
w, err := zf.CreateHeader(h)
if err != nil {
return err
}
_, err = io.Copy(w, f)
if err != nil {
return err
}
}
if err := zf.Close(); err != nil {
return err
}
if err := tf.Close(); err != nil {
return err
}
path = tf.Name()
} }
digest, err := createBlob(cmd, client, path) digest, err := createBlob(cmd, client, path)
@ -191,10 +111,17 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest)) name := c.Name
if c.Name == "model" {
name = "from"
}
re := regexp.MustCompile(fmt.Sprintf(`(?im)^(%s)\s+%s\s*$`, name, c.Args))
modelfile = re.ReplaceAll(modelfile, []byte("$1 @"+digest))
} }
} }
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
spinner.Stop() spinner.Stop()
@ -228,6 +155,88 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return "", err
}
defer tempfile.Close()
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
if err != nil {
return "", err
} else if len(tfiles) == 0 {
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return "", err
}
}
files := []string{}
files = append(files, tfiles...)
if len(files) == 0 {
return "", fmt.Errorf("no models were found in '%s'", path)
}
// add the safetensor/torch config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "params.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))
for _, fn := range files {
f, err := os.Open(fn)
// just skip whatever files aren't there
if os.IsNotExist(err) {
if strings.HasSuffix(fn, "tokenizer.model") {
// try the parent dir before giving up
parentDir := filepath.Dir(path)
newFn := filepath.Join(parentDir, "tokenizer.model")
f, err = os.Open(newFn)
if os.IsNotExist(err) {
continue
} else if err != nil {
return "", err
}
} else {
continue
}
} else if err != nil {
return "", err
}
fi, err := f.Stat()
if err != nil {
return "", err
}
h, err := zip.FileInfoHeader(fi)
if err != nil {
return "", err
}
h.Name = filepath.Base(fn)
h.Method = zip.Store
w, err := zipfile.CreateHeader(h)
if err != nil {
return "", err
}
_, err = io.Copy(w, f)
if err != nil {
return "", err
}
}
return tempfile.Name(), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {