From 41e03ede95d81278a24424132a0ed584554d022f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 25 Apr 2024 14:41:30 -0700 Subject: [PATCH] check file type before zip --- cmd/cmd.go | 114 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 70 insertions(+), 44 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 4585533f..0a1dc7ed 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -165,71 +165,97 @@ func tempZipFiles(path string) (string, error) { zipfile := zip.NewWriter(tempfile) defer zipfile.Close() - tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin")) - if err != nil { - return "", err - } else if len(tfiles) == 0 { - tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors")) + detectContentType := func(path string) (string, error) { + f, err := os.Open(path) if err != nil { return "", err } - } + defer f.Close() - files := []string{} - files = append(files, tfiles...) + var b bytes.Buffer + b.Grow(512) - if len(files) == 0 { - return "", fmt.Errorf("no models were found in '%s'", path) - } - - // add the safetensor/torch config file + tokenizer - files = append(files, filepath.Join(path, "config.json")) - files = append(files, filepath.Join(path, "params.json")) - files = append(files, filepath.Join(path, "added_tokens.json")) - files = append(files, filepath.Join(path, "tokenizer.model")) - - for _, fn := range files { - f, err := os.Open(fn) - - // just skip whatever files aren't there - if os.IsNotExist(err) { - if strings.HasSuffix(fn, "tokenizer.model") { - // try the parent dir before giving up - parentDir := filepath.Dir(path) - newFn := filepath.Join(parentDir, "tokenizer.model") - f, err = os.Open(newFn) - if os.IsNotExist(err) { - continue - } else if err != nil { - return "", err - } - } else { - continue - } - } else if err != nil { + if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { return "", err } + contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") + return contentType, nil + } + + glob := func(pattern, contentType string) ([]string, error) { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + + for _, safetensor := range matches { + if ct, err := detectContentType(safetensor); err != nil { + return nil, err + } else if ct != contentType { + return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor) + } + } + + return matches, nil + } + + var files []string + if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 { + // safetensors files might be unresolved git lfs references; skip if they are + // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors + files = append(files, st...) + } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { + // pytorch files might also be unresolved git lfs references; skip if they are + // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin + files = append(files, pt...) + } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 { + // pytorch files might also be unresolved git lfs references; skip if they are + // covers consolidated.x.pth, consolidated.pth + files = append(files, pt...) + } else { + return "", errors.New("no safetensors or torch files found") + } + + // add configuration files, json files are detected as text/plain + js, err := glob(filepath.Join(path, "*.json"), "text/plain") + if err != nil { + return "", err + } + files = append(files, js...) + + if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { + // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob + // tokenizer.model might be a unresolved git lfs reference; error if it is + files = append(files, tks...) + } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { + // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) + files = append(files, tks...) + } + + for _, file := range files { + f, err := os.Open(file) + if err != nil { + return "", err + } + defer f.Close() + fi, err := f.Stat() if err != nil { return "", err } - h, err := zip.FileInfoHeader(fi) + zfi, err := zip.FileInfoHeader(fi) if err != nil { return "", err } - h.Name = filepath.Base(fn) - h.Method = zip.Store - - w, err := zipfile.CreateHeader(h) + zf, err := zipfile.CreateHeader(zfi) if err != nil { return "", err } - _, err = io.Copy(w, f) - if err != nil { + if _, err := io.Copy(zf, f); err != nil { return "", err } }