diff --git a/convert/convert.go b/convert/convert.go
index 30c5a53f..b9461e4f 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -5,9 +5,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/fs"
 	"log/slog"
-	"os"
-	"path/filepath"
 
 	"github.com/ollama/ollama/llm"
 )
@@ -67,8 +66,8 @@ type Converter interface {
 // and files it finds in the input path.
 // Supported input model formats include safetensors.
 // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
-func Convert(path string, ws io.WriteSeeker) error {
-	bts, err := os.ReadFile(filepath.Join(path, "config.json"))
+func Convert(fsys fs.FS, ws io.WriteSeeker) error {
+	bts, err := fs.ReadFile(fsys, "config.json")
 	if err != nil {
 		return err
 	}
@@ -98,7 +97,7 @@ func Convert(path string, ws io.WriteSeeker) error {
 		return err
 	}
 
-	t, err := parseTokenizer(path, conv.specialTokenTypes())
+	t, err := parseTokenizer(fsys, conv.specialTokenTypes())
 	if err != nil {
 		return err
 	}
@@ -114,7 +113,7 @@ func Convert(path string, ws io.WriteSeeker) error {
 		slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
 	}
 
-	ts, err := parseTensors(path)
+	ts, err := parseTensors(fsys)
 	if err != nil {
 		return err
 	}
diff --git a/convert/convert_test.go b/convert/convert_test.go
index 0fbd436f..67a2fcfe 100644
--- a/convert/convert_test.go
+++ b/convert/convert_test.go
@@ -6,6 +6,7 @@ import (
 	"flag"
 	"fmt"
 	"io"
+	"io/fs"
 	"log/slog"
 	"math"
 	"os"
@@ -17,7 +18,7 @@ import (
 	"golang.org/x/exp/maps"
 )
 
-func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
+func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
 	t.Helper()
 
 	f, err := os.CreateTemp(t.TempDir(), "f16")
@@ -26,7 +27,7 @@ func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
 	}
 	defer f.Close()
 
-	if err := Convert(d, f); err != nil {
+	if err := Convert(fsys, f); err != nil {
 		t.Fatal(err)
 	}
 
@@ -76,7 +77,7 @@ func TestConvertFull(t *testing.T) {
 			t.Skipf("%s not found", p)
 		}
 
-		f, kv, tensors := convertFull(t, p)
+		f, kv, tensors := convertFull(t, os.DirFS(p))
 		actual := make(map[string]string)
 		for k, v := range kv {
 			if s, ok := v.(json.Marshaler); !ok {
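Convert now accepts any fs.FS instead of a directory path, which is what lets the test above pass os.DirFS(p) and the server below pass a zip-backed FS. A minimal sketch of calling the new entry point; the input directory and output file names are hypothetical:

```go
package main

import (
	"log"
	"os"

	"github.com/ollama/ollama/convert"
)

func main() {
	// The output must be an io.WriteSeeker; *os.File qualifies.
	// "model.f16.gguf" and "path/to/model" are placeholder names.
	f, err := os.Create("model.f16.gguf")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// os.DirFS adapts a local directory to fs.FS, mirroring
	// convertFull's use of os.DirFS in the updated test.
	if err := convert.Convert(os.DirFS("path/to/model"), f); err != nil {
		log.Fatal(err)
	}
}
```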
diff --git a/convert/fs.go b/convert/fs.go
new file mode 100644
index 00000000..bf6da6c2
--- /dev/null
+++ b/convert/fs.go
@@ -0,0 +1,58 @@
+package convert
+
+import (
+	"archive/zip"
+	"errors"
+	"io"
+	"io/fs"
+	"os"
+	"path/filepath"
+)
+
+type ZipReader struct {
+	r *zip.Reader
+	p string
+
+	// limit is the maximum size of a file that can be read directly
+	// from the zip archive. Files larger than this size will be extracted to disk.
+	limit int64
+}
+
+func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
+	return &ZipReader{r, p, limit}
+}
+
+func (z *ZipReader) Open(name string) (fs.File, error) {
+	r, err := z.r.Open(name)
+	if err != nil {
+		return nil, err
+	}
+	defer r.Close()
+
+	if fi, err := r.Stat(); err != nil {
+		return nil, err
+	} else if fi.Size() < z.limit {
+		return r, nil
+	}
+
+	if !filepath.IsLocal(name) {
+		return nil, zip.ErrInsecurePath
+	}
+
+	n := filepath.Join(z.p, name)
+	if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
+		w, err := os.Create(n)
+		if err != nil {
+			return nil, err
+		}
+		defer w.Close()
+
+		if _, err := io.Copy(w, r); err != nil {
+			return nil, err
+		}
+	} else if err != nil {
+		return nil, err
+	}
+
+	return os.Open(n)
+}
diff --git a/convert/reader.go b/convert/reader.go
index 11ccaa81..56a8ae89 100644
--- a/convert/reader.go
+++ b/convert/reader.go
@@ -3,7 +3,7 @@ package convert
 import (
 	"errors"
 	"io"
-	"path/filepath"
+	"io/fs"
 	"strings"
 )
 
@@ -55,8 +55,8 @@ func (t *tensorBase) SetRepacker(fn repacker) {
 
 type repacker func(string, []float32, []uint64) ([]float32, error)
 
-func parseTensors(d string) ([]Tensor, error) {
-	patterns := map[string]func(...string) ([]Tensor, error){
+func parseTensors(fsys fs.FS) ([]Tensor, error) {
+	patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
 		"model-*-of-*.safetensors": parseSafetensors,
 		"model.safetensors":        parseSafetensors,
 		"pytorch_model-*-of-*.bin": parseTorch,
@@ -65,13 +65,13 @@ func parseTensors(d string) ([]Tensor, error) {
 	for pattern, parseFn := range patterns {
-		matches, err := filepath.Glob(filepath.Join(d, pattern))
+		matches, err := fs.Glob(fsys, pattern)
 		if err != nil {
 			return nil, err
 		}
 
 		if len(matches) > 0 {
-			return parseFn(matches...)
+			return parseFn(fsys, matches...)
 		}
 	}
 
 	return nil, errors.New("unknown tensor format")
 }
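ZipReader adapts a zip archive to fs.FS with a spill-to-disk escape hatch: entries smaller than limit are served straight from the archive, while larger ones are extracted under p (after a filepath.IsLocal check) and reopened from disk. A condensed sketch of the intended wiring, mirroring what server/model.go does further down; the archive and output names are placeholders:

```go
package main

import (
	"archive/zip"
	"log"
	"os"

	"github.com/ollama/ollama/convert"
)

func main() {
	f, err := os.Open("model.zip") // placeholder archive name
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	fi, err := f.Stat()
	if err != nil {
		log.Fatal(err)
	}

	r, err := zip.NewReader(f, fi.Size())
	if err != nil {
		log.Fatal(err)
	}

	// Entries at or above the limit are extracted into this directory;
	// smaller ones are read directly from the archive.
	p, err := os.MkdirTemp("", "convert")
	if err != nil {
		log.Fatal(err)
	}
	defer os.RemoveAll(p)

	out, err := os.Create("model.f16.gguf") // placeholder output
	if err != nil {
		log.Fatal(err)
	}
	defer out.Close()

	// 32<<20 is the same 32 MiB direct-read limit the server uses.
	if err := convert.Convert(convert.NewZipReader(r, p, 32<<20), out); err != nil {
		log.Fatal(err)
	}
}
```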
diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go
index c5fe663c..1c169504 100644
--- a/convert/reader_safetensors.go
+++ b/convert/reader_safetensors.go
@@ -6,7 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"os"
+	"io/fs"
 	"slices"
 
 	"github.com/d4l3k/go-bfloat16"
@@ -20,10 +20,10 @@ type safetensorMetadata struct {
 	Offsets []int64 `json:"data_offsets"`
 }
 
-func parseSafetensors(ps ...string) ([]Tensor, error) {
+func parseSafetensors(fsys fs.FS, ps ...string) ([]Tensor, error) {
 	var ts []Tensor
 	for _, p := range ps {
-		f, err := os.Open(p)
+		f, err := fsys.Open(p)
 		if err != nil {
 			return nil, err
 		}
@@ -50,6 +50,7 @@ func parseSafetensors(ps ...string) ([]Tensor, error) {
 		for _, key := range keys {
 			if value := headers[key]; value.Type != "" {
 				ts = append(ts, safetensor{
+					fs:     fsys,
 					path:   p,
 					dtype:  value.Type,
 					offset: safetensorsPad(n, value.Offsets[0]),
@@ -72,6 +73,7 @@ func safetensorsPad(n, offset int64) int64 {
 }
 
 type safetensor struct {
+	fs     fs.FS
 	path   string
 	dtype  string
 	offset int64
@@ -80,14 +82,20 @@ type safetensor struct {
 }
 
 func (st safetensor) WriteTo(w io.Writer) (int64, error) {
-	f, err := os.Open(st.path)
+	f, err := st.fs.Open(st.path)
 	if err != nil {
 		return 0, err
 	}
 	defer f.Close()
 
-	if _, err = f.Seek(st.offset, io.SeekStart); err != nil {
-		return 0, err
+	if seeker, ok := f.(io.Seeker); ok {
+		if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
+			return 0, err
+		}
+	} else {
+		if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
+			return 0, err
+		}
 	}
 
 	var f32s []float32
diff --git a/convert/reader_torch.go b/convert/reader_torch.go
index 1428706e..531996bf 100644
--- a/convert/reader_torch.go
+++ b/convert/reader_torch.go
@@ -2,12 +2,13 @@ package convert
 
 import (
 	"io"
+	"io/fs"
 
 	"github.com/nlpodyssey/gopickle/pytorch"
 	"github.com/nlpodyssey/gopickle/types"
 )
 
-func parseTorch(ps ...string) ([]Tensor, error) {
+func parseTorch(fsys fs.FS, ps ...string) ([]Tensor, error) {
 	var ts []Tensor
 	for _, p := range ps {
 		pt, err := pytorch.Load(p)
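The WriteTo change above exists because a plain fs.File is not guaranteed to be an io.Seeker (a compressed zip entry, for example), so the tensor offset is either seeked to or consumed with io.CopyN. The same pattern in isolation, as a self-contained sketch with a hypothetical helper name:

```go
package main

import (
	"io"
	"io/fs"
)

// skipTo positions f at offset: it seeks when the underlying file
// supports it, and otherwise reads and discards offset bytes.
// (Hypothetical helper, not part of the convert package.)
func skipTo(f fs.File, offset int64) error {
	if seeker, ok := f.(io.Seeker); ok {
		_, err := seeker.Seek(offset, io.SeekStart)
		return err
	}

	// fs.File is an io.Reader, so we can always fall back to
	// reading forward and throwing the bytes away.
	_, err := io.CopyN(io.Discard, f, offset)
	return err
}
```

Each WriteTo opens a fresh handle for its tensor, so the discard path always counts from the start of the entry and never needs to move backwards.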
diff --git a/convert/tokenizer.go b/convert/tokenizer.go
index 43d8c14e..cca40eb0 100644
--- a/convert/tokenizer.go
+++ b/convert/tokenizer.go
@@ -7,9 +7,9 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io/fs"
 	"log/slog"
 	"os"
-	"path/filepath"
 	"slices"
 )
 
@@ -32,8 +32,8 @@ type Tokenizer struct {
 	Template string
 }
 
-func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
-	v, err := parseVocabulary(d)
+func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
+	v, err := parseVocabulary(fsys)
 	if err != nil {
 		return nil, err
 	}
@@ -44,7 +44,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
 	}
 
 	addedTokens := make(map[string]token)
-	if f, err := os.Open(filepath.Join(d, "tokenizer.json")); errors.Is(err, os.ErrNotExist) {
+	if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
 	} else if err != nil {
 		return nil, err
 	} else {
@@ -87,7 +87,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
 		}
 	}
 
-	if f, err := os.Open(filepath.Join(d, "tokenizer_config.json")); errors.Is(err, os.ErrNotExist) {
+	if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
 	} else if err != nil {
 		return nil, err
 	} else {
@@ -172,8 +172,8 @@ type Vocabulary struct {
 	Types  []int32
 }
 
-func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
-	f, err := os.Open(filepath.Join(p, "tokenizer.json"))
+func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
+	f, err := fsys.Open("tokenizer.json")
 	if err != nil {
 		return nil, err
 	}
@@ -219,20 +219,20 @@ func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
 	return &v, nil
 }
 
-func parseVocabulary(d string) (*Vocabulary, error) {
-	patterns := map[string]func(string) (*Vocabulary, error){
+func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
+	patterns := map[string]func(fs.FS) (*Vocabulary, error){
 		"tokenizer.model": parseSentencePiece,
 		"tokenizer.json":  parseVocabularyFromTokenizer,
 	}
 
 	for pattern, parseFn := range patterns {
-		if _, err := os.Stat(filepath.Join(d, pattern)); errors.Is(err, os.ErrNotExist) {
+		if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) {
 			continue
 		} else if err != nil {
 			return nil, err
 		}
 
-		return parseFn(d)
+		return parseFn(fsys)
 	}
 
 	return nil, errors.New("unknown tensor format")
diff --git a/convert/tokenizer_spm.go b/convert/tokenizer_spm.go
index 75d9fe26..babf702c 100644
--- a/convert/tokenizer_spm.go
+++ b/convert/tokenizer_spm.go
@@ -5,8 +5,8 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io/fs"
 	"os"
-	"path/filepath"
 	"slices"
 
 	"google.golang.org/protobuf/proto"
@@ -14,8 +14,8 @@ import (
 	"github.com/ollama/ollama/convert/sentencepiece"
 )
 
-func parseSentencePiece(d string) (*Vocabulary, error) {
-	bts, err := os.ReadFile(filepath.Join(d, "tokenizer.model"))
+func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
+	bts, err := fs.ReadFile(fsys, "tokenizer.model")
 	if err != nil {
 		return nil, err
 	}
@@ -41,7 +41,7 @@ func parseSentencePiece(d string) (*Vocabulary, error) {
 		}
 	}
 
-	f, err := os.Open(filepath.Join(d, "added_tokens.json"))
+	f, err := fsys.Open("added_tokens.json")
 	if errors.Is(err, os.ErrNotExist) {
 		return &v, nil
 	} else if err != nil {
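parseVocabulary probes the FS for known tokenizer files and dispatches on the first hit. A standalone sketch of that probe with a hypothetical helper name; it deliberately uses a slice where the original ranges over a map, since a slice keeps the probe order deterministic and matches the "tokenizer.json (preferred)" promise in Convert's doc comment, which a Go map range does not guarantee:

```go
package main

import (
	"errors"
	"io/fs"
)

// pickVocabularyFile returns the first known tokenizer file present in
// fsys. (Hypothetical helper illustrating parseVocabulary's probe.)
func pickVocabularyFile(fsys fs.FS) (string, error) {
	for _, name := range []string{"tokenizer.json", "tokenizer.model"} {
		if _, err := fs.Stat(fsys, name); errors.Is(err, fs.ErrNotExist) {
			continue
		} else if err != nil {
			return "", err
		}
		return name, nil
	}

	return "", errors.New("unknown tokenizer format")
}
```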
diff --git a/server/model.go b/server/model.go
index 81272a34..f2946a0b 100644
--- a/server/model.go
+++ b/server/model.go
@@ -81,88 +81,43 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 	return layers, nil
 }
 
-func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
-	stat, err := file.Stat()
-	if err != nil {
-		return err
-	}
-
-	r, err := zip.NewReader(file, stat.Size())
-	if err != nil {
-		return err
-	}
-
-	fn(api.ProgressResponse{Status: "unpacking model metadata"})
-	for _, f := range r.File {
-		if !filepath.IsLocal(f.Name) {
-			return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
-		}
-
-		n := filepath.Join(p, f.Name)
-		if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
-			return err
-		}
-
-		// TODO(mxyng): this should not write out all files to disk
-		outfile, err := os.Create(n)
-		if err != nil {
-			return err
-		}
-		defer outfile.Close()
-
-		infile, err := f.Open()
-		if err != nil {
-			return err
-		}
-		defer infile.Close()
-
-		if _, err = io.Copy(outfile, infile); err != nil {
-			return err
-		}
-
-		if err := outfile.Close(); err != nil {
-			return err
-		}
-
-		if err := infile.Close(); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
-func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
-	tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
+func parseFromZipFile(_ context.Context, f *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+	fi, err := f.Stat()
 	if err != nil {
 		return nil, err
 	}
-	defer os.RemoveAll(tempDir)
-	if err := extractFromZipFile(tempDir, file, fn); err != nil {
+	r, err := zip.NewReader(f, fi.Size())
+	if err != nil {
 		return nil, err
 	}
 
+	p, err := os.MkdirTemp(filepath.Dir(f.Name()), "")
+	if err != nil {
+		return nil, err
+	}
+	defer os.RemoveAll(p)
+
 	fn(api.ProgressResponse{Status: "converting model"})
-	// TODO(mxyng): this should write directly into a layer
 	// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
-	temp, err := os.CreateTemp(tempDir, "fp16")
+	t, err := os.CreateTemp(p, "fp16")
 	if err != nil {
 		return nil, err
 	}
-	defer temp.Close()
-	defer os.Remove(temp.Name())
+	defer t.Close()
+	defer os.Remove(t.Name())
 
-	if err := convert.Convert(tempDir, temp); err != nil {
+	fn(api.ProgressResponse{Status: "converting model"})
+	if err := convert.Convert(convert.NewZipReader(r, p, 32<<20), t); err != nil {
 		return nil, err
 	}
 
-	if _, err := temp.Seek(0, io.SeekStart); err != nil {
+	if _, err := t.Seek(0, io.SeekStart); err != nil {
 		return nil, err
 	}
 
-	layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
+	layer, err := NewLayer(t, "application/vnd.ollama.image.model")
 	if err != nil {
 		return nil, err
 	}
diff --git a/server/model_test.go b/server/model_test.go
index 5829adfc..0a2225d5 100644
--- a/server/model_test.go
+++ b/server/model_test.go
@@ -1,16 +1,11 @@
 package server
 
 import (
-	"archive/zip"
 	"bytes"
 	"encoding/json"
-	"errors"
 	"fmt"
-	"io"
 	"os"
 	"path/filepath"
-	"slices"
-	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -18,103 +13,6 @@ import (
 	"github.com/ollama/ollama/template"
 )
 
-func createZipFile(t *testing.T, name string) *os.File {
-	t.Helper()
-
-	f, err := os.CreateTemp(t.TempDir(), "")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	zf := zip.NewWriter(f)
-	defer zf.Close()
-
-	zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
-		t.Fatal(err)
-	}
-
-	return f
-}
-
-func TestExtractFromZipFile(t *testing.T) {
-	cases := []struct {
-		name   string
-		expect []string
-		err    error
-	}{
-		{
-			name:   "good",
-			expect: []string{"good"},
-		},
-		{
-			name:   strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
-			expect: []string{filepath.Join("to", "good")},
-		},
-		{
-			name:   strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
-			expect: []string{"good"},
-		},
-		{
-			name:   strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
-			expect: []string{"good"},
-		},
-		{
-			name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
-			err:  zip.ErrInsecurePath,
-		},
-		{
-			name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
-			err:  zip.ErrInsecurePath,
-		},
-	}
-
-	for _, tt := range cases {
-		t.Run(tt.name, func(t *testing.T) {
-			f := createZipFile(t, tt.name)
-			defer f.Close()
-
-			tempDir := t.TempDir()
-			if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
-				t.Fatal(err)
-			}
-
-			var matches []string
-			if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
-				if err != nil {
-					return err
-				}
-
-				if !fi.IsDir() {
-					matches = append(matches, p)
-				}
-
-				return nil
-			}); err != nil {
-				t.Fatal(err)
-			}
-
-			var actual []string
-			for _, match := range matches {
-				rel, err := filepath.Rel(tempDir, match)
-				if err != nil {
-					t.Error(err)
-				}
-
-				actual = append(actual, rel)
-			}
-
-			if !slices.Equal(actual, tt.expect) {
-				t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
-			}
-		})
-	}
-}
-
 func readFile(t *testing.T, base, name string) *bytes.Buffer {
 	t.Helper()
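One payoff of threading fs.FS through the converter: tests no longer need a real directory tree on disk. A sketch using testing/fstest; the file contents below are placeholders rather than a functional model, so the conversion itself is expected to fail, and only the plumbing is illustrated:

```go
package convert_test

import (
	"os"
	"testing"
	"testing/fstest"

	"github.com/ollama/ollama/convert"
)

// Illustrative only: fstest.MapFS satisfies fs.FS, so an in-memory
// "model directory" can be handed straight to Convert.
func TestConvertAcceptsAnyFS(t *testing.T) {
	fsys := fstest.MapFS{
		"config.json": &fstest.MapFile{Data: []byte(`{}`)}, // placeholder
	}

	f, err := os.CreateTemp(t.TempDir(), "f16")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	// An empty config is not a convertible model, so an error here is
	// the expected outcome; the point is that the call compiles and
	// runs against an in-memory FS.
	if err := convert.Convert(fsys, f); err == nil {
		t.Log("unexpectedly succeeded on a placeholder model")
	}
}
```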