check file type before zip
This commit is contained in:
parent
7fea1ecdf6
commit
41e03ede95
1 changed files with 70 additions and 44 deletions
114
cmd/cmd.go
114
cmd/cmd.go
|
@ -165,71 +165,97 @@ func tempZipFiles(path string) (string, error) {
|
||||||
zipfile := zip.NewWriter(tempfile)
|
zipfile := zip.NewWriter(tempfile)
|
||||||
defer zipfile.Close()
|
defer zipfile.Close()
|
||||||
|
|
||||||
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
|
detectContentType := func(path string) (string, error) {
|
||||||
if err != nil {
|
f, err := os.Open(path)
|
||||||
return "", err
|
|
||||||
} else if len(tfiles) == 0 {
|
|
||||||
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
defer f.Close()
|
||||||
|
|
||||||
files := []string{}
|
var b bytes.Buffer
|
||||||
files = append(files, tfiles...)
|
b.Grow(512)
|
||||||
|
|
||||||
if len(files) == 0 {
|
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
|
||||||
return "", fmt.Errorf("no models were found in '%s'", path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// add the safetensor/torch config file + tokenizer
|
|
||||||
files = append(files, filepath.Join(path, "config.json"))
|
|
||||||
files = append(files, filepath.Join(path, "params.json"))
|
|
||||||
files = append(files, filepath.Join(path, "added_tokens.json"))
|
|
||||||
files = append(files, filepath.Join(path, "tokenizer.model"))
|
|
||||||
|
|
||||||
for _, fn := range files {
|
|
||||||
f, err := os.Open(fn)
|
|
||||||
|
|
||||||
// just skip whatever files aren't there
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
if strings.HasSuffix(fn, "tokenizer.model") {
|
|
||||||
// try the parent dir before giving up
|
|
||||||
parentDir := filepath.Dir(path)
|
|
||||||
newFn := filepath.Join(parentDir, "tokenizer.model")
|
|
||||||
f, err = os.Open(newFn)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
continue
|
|
||||||
} else if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
|
||||||
|
return contentType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
glob := func(pattern, contentType string) ([]string, error) {
|
||||||
|
matches, err := filepath.Glob(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, safetensor := range matches {
|
||||||
|
if ct, err := detectContentType(safetensor); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if ct != contentType {
|
||||||
|
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var files []string
|
||||||
|
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||||
|
// safetensors files might be unresolved git lfs references; skip if they are
|
||||||
|
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||||
|
files = append(files, st...)
|
||||||
|
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||||
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
|
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||||
|
files = append(files, pt...)
|
||||||
|
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
|
||||||
|
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||||
|
// covers consolidated.x.pth, consolidated.pth
|
||||||
|
files = append(files, pt...)
|
||||||
|
} else {
|
||||||
|
return "", errors.New("no safetensors or torch files found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// add configuration files, json files are detected as text/plain
|
||||||
|
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
files = append(files, js...)
|
||||||
|
|
||||||
|
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||||
|
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||||
|
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||||
|
files = append(files, tks...)
|
||||||
|
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||||
|
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||||
|
files = append(files, tks...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
f, err := os.Open(file)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
fi, err := f.Stat()
|
fi, err := f.Stat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
h, err := zip.FileInfoHeader(fi)
|
zfi, err := zip.FileInfoHeader(fi)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
h.Name = filepath.Base(fn)
|
zf, err := zipfile.CreateHeader(zfi)
|
||||||
h.Method = zip.Store
|
|
||||||
|
|
||||||
w, err := zipfile.CreateHeader(h)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = io.Copy(w, f)
|
if _, err := io.Copy(zf, f); err != nil {
|
||||||
if err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue