convert: only extract large files
This commit is contained in:
parent
781fc2d576
commit
eafc607abb
10 changed files with 120 additions and 200 deletions
|
@ -5,9 +5,8 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
@ -67,8 +66,8 @@ type Converter interface {
|
||||||
// and files it finds in the input path.
|
// and files it finds in the input path.
|
||||||
// Supported input model formats include safetensors.
|
// Supported input model formats include safetensors.
|
||||||
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
||||||
func Convert(path string, ws io.WriteSeeker) error {
|
func Convert(fsys fs.FS, ws io.WriteSeeker) error {
|
||||||
bts, err := os.ReadFile(filepath.Join(path, "config.json"))
|
bts, err := fs.ReadFile(fsys, "config.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -98,7 +97,7 @@ func Convert(path string, ws io.WriteSeeker) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t, err := parseTokenizer(path, conv.specialTokenTypes())
|
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -114,7 +113,7 @@ func Convert(path string, ws io.WriteSeeker) error {
|
||||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
ts, err := parseTensors(path)
|
ts, err := parseTensors(fsys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
|
@ -17,7 +18,7 @@ import (
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
)
|
)
|
||||||
|
|
||||||
func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
|
func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), "f16")
|
f, err := os.CreateTemp(t.TempDir(), "f16")
|
||||||
|
@ -26,7 +27,7 @@ func convertFull(t *testing.T, d string) (*os.File, llm.KV, llm.Tensors) {
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
if err := Convert(d, f); err != nil {
|
if err := Convert(fsys, f); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +77,7 @@ func TestConvertFull(t *testing.T) {
|
||||||
t.Skipf("%s not found", p)
|
t.Skipf("%s not found", p)
|
||||||
}
|
}
|
||||||
|
|
||||||
f, kv, tensors := convertFull(t, p)
|
f, kv, tensors := convertFull(t, os.DirFS(p))
|
||||||
actual := make(map[string]string)
|
actual := make(map[string]string)
|
||||||
for k, v := range kv {
|
for k, v := range kv {
|
||||||
if s, ok := v.(json.Marshaler); !ok {
|
if s, ok := v.(json.Marshaler); !ok {
|
||||||
|
|
58
convert/fs.go
Normal file
58
convert/fs.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/zip"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ZipReader struct {
|
||||||
|
r *zip.Reader
|
||||||
|
p string
|
||||||
|
|
||||||
|
// limit is the maximum size of a file that can be read directly
|
||||||
|
// from the zip archive. Files larger than this size will be extracted
|
||||||
|
limit int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
|
||||||
|
return &ZipReader{r, p, limit}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (z *ZipReader) Open(name string) (fs.File, error) {
|
||||||
|
r, err := z.r.Open(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
if fi, err := r.Stat(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if fi.Size() < z.limit {
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !filepath.IsLocal(name) {
|
||||||
|
return nil, zip.ErrInsecurePath
|
||||||
|
}
|
||||||
|
|
||||||
|
n := filepath.Join(z.p, name)
|
||||||
|
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
|
||||||
|
w, err := os.Create(n)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer w.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(w, r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.Open(n)
|
||||||
|
}
|
|
@ -3,7 +3,7 @@ package convert
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"path/filepath"
|
"io/fs"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,8 +55,8 @@ func (t *tensorBase) SetRepacker(fn repacker) {
|
||||||
|
|
||||||
type repacker func(string, []float32, []uint64) ([]float32, error)
|
type repacker func(string, []float32, []uint64) ([]float32, error)
|
||||||
|
|
||||||
func parseTensors(d string) ([]Tensor, error) {
|
func parseTensors(fsys fs.FS) ([]Tensor, error) {
|
||||||
patterns := map[string]func(...string) ([]Tensor, error){
|
patterns := map[string]func(fs.FS, ...string) ([]Tensor, error){
|
||||||
"model-*-of-*.safetensors": parseSafetensors,
|
"model-*-of-*.safetensors": parseSafetensors,
|
||||||
"model.safetensors": parseSafetensors,
|
"model.safetensors": parseSafetensors,
|
||||||
"pytorch_model-*-of-*.bin": parseTorch,
|
"pytorch_model-*-of-*.bin": parseTorch,
|
||||||
|
@ -65,13 +65,13 @@ func parseTensors(d string) ([]Tensor, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for pattern, parseFn := range patterns {
|
for pattern, parseFn := range patterns {
|
||||||
matches, err := filepath.Glob(filepath.Join(d, pattern))
|
matches, err := fs.Glob(fsys, pattern)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
return parseFn(matches...)
|
return parseFn(fsys, matches...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"io/fs"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/d4l3k/go-bfloat16"
|
"github.com/d4l3k/go-bfloat16"
|
||||||
|
@ -20,10 +20,10 @@ type safetensorMetadata struct {
|
||||||
Offsets []int64 `json:"data_offsets"`
|
Offsets []int64 `json:"data_offsets"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSafetensors(ps ...string) ([]Tensor, error) {
|
func parseSafetensors(fsys fs.FS, ps ...string) ([]Tensor, error) {
|
||||||
var ts []Tensor
|
var ts []Tensor
|
||||||
for _, p := range ps {
|
for _, p := range ps {
|
||||||
f, err := os.Open(p)
|
f, err := fsys.Open(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,7 @@ func parseSafetensors(ps ...string) ([]Tensor, error) {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if value := headers[key]; value.Type != "" {
|
if value := headers[key]; value.Type != "" {
|
||||||
ts = append(ts, safetensor{
|
ts = append(ts, safetensor{
|
||||||
|
fs: fsys,
|
||||||
path: p,
|
path: p,
|
||||||
dtype: value.Type,
|
dtype: value.Type,
|
||||||
offset: safetensorsPad(n, value.Offsets[0]),
|
offset: safetensorsPad(n, value.Offsets[0]),
|
||||||
|
@ -72,6 +73,7 @@ func safetensorsPad(n, offset int64) int64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
type safetensor struct {
|
type safetensor struct {
|
||||||
|
fs fs.FS
|
||||||
path string
|
path string
|
||||||
dtype string
|
dtype string
|
||||||
offset int64
|
offset int64
|
||||||
|
@ -80,15 +82,21 @@ type safetensor struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||||
f, err := os.Open(st.path)
|
f, err := st.fs.Open(st.path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
if _, err = f.Seek(st.offset, io.SeekStart); err != nil {
|
if seeker, ok := f.(io.Seeker); ok {
|
||||||
|
if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var f32s []float32
|
var f32s []float32
|
||||||
switch st.dtype {
|
switch st.dtype {
|
||||||
|
|
|
@ -2,12 +2,13 @@ package convert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
|
||||||
"github.com/nlpodyssey/gopickle/pytorch"
|
"github.com/nlpodyssey/gopickle/pytorch"
|
||||||
"github.com/nlpodyssey/gopickle/types"
|
"github.com/nlpodyssey/gopickle/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseTorch(ps ...string) ([]Tensor, error) {
|
func parseTorch(fsys fs.FS, ps ...string) ([]Tensor, error) {
|
||||||
var ts []Tensor
|
var ts []Tensor
|
||||||
for _, p := range ps {
|
for _, p := range ps {
|
||||||
pt, err := pytorch.Load(p)
|
pt, err := pytorch.Load(p)
|
||||||
|
|
|
@ -7,9 +7,9 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,8 +32,8 @@ type Tokenizer struct {
|
||||||
Template string
|
Template string
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
|
func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) {
|
||||||
v, err := parseVocabulary(d)
|
v, err := parseVocabulary(fsys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
addedTokens := make(map[string]token)
|
addedTokens := make(map[string]token)
|
||||||
if f, err := os.Open(filepath.Join(d, "tokenizer.json")); errors.Is(err, os.ErrNotExist) {
|
if f, err := fsys.Open("tokenizer.json"); errors.Is(err, os.ErrNotExist) {
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
|
@ -87,7 +87,7 @@ func parseTokenizer(d string, specialTokenTypes []string) (*Tokenizer, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if f, err := os.Open(filepath.Join(d, "tokenizer_config.json")); errors.Is(err, os.ErrNotExist) {
|
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
|
@ -172,8 +172,8 @@ type Vocabulary struct {
|
||||||
Types []int32
|
Types []int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
|
func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
|
||||||
f, err := os.Open(filepath.Join(p, "tokenizer.json"))
|
f, err := fsys.Open("tokenizer.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -219,20 +219,20 @@ func parseVocabularyFromTokenizer(p string) (*Vocabulary, error) {
|
||||||
return &v, nil
|
return &v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseVocabulary(d string) (*Vocabulary, error) {
|
func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
|
||||||
patterns := map[string]func(string) (*Vocabulary, error){
|
patterns := map[string]func(fs.FS) (*Vocabulary, error){
|
||||||
"tokenizer.model": parseSentencePiece,
|
"tokenizer.model": parseSentencePiece,
|
||||||
"tokenizer.json": parseVocabularyFromTokenizer,
|
"tokenizer.json": parseVocabularyFromTokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
for pattern, parseFn := range patterns {
|
for pattern, parseFn := range patterns {
|
||||||
if _, err := os.Stat(filepath.Join(d, pattern)); errors.Is(err, os.ErrNotExist) {
|
if _, err := fs.Stat(fsys, pattern); errors.Is(err, os.ErrNotExist) {
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseFn(d)
|
return parseFn(fsys)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("unknown tensor format")
|
return nil, errors.New("unknown tensor format")
|
||||||
|
|
|
@ -5,8 +5,8 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
@ -14,8 +14,8 @@ import (
|
||||||
"github.com/ollama/ollama/convert/sentencepiece"
|
"github.com/ollama/ollama/convert/sentencepiece"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseSentencePiece(d string) (*Vocabulary, error) {
|
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||||
bts, err := os.ReadFile(filepath.Join(d, "tokenizer.model"))
|
bts, err := fs.ReadFile(fsys, "tokenizer.model")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -41,7 +41,7 @@ func parseSentencePiece(d string) (*Vocabulary, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.Open(filepath.Join(d, "added_tokens.json"))
|
f, err := fsys.Open("added_tokens.json")
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
return &v, nil
|
return &v, nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|
|
@ -81,88 +81,43 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||||
return layers, nil
|
return layers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {
|
func parseFromZipFile(_ context.Context, f *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||||
stat, err := file.Stat()
|
fi, err := f.Stat()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := zip.NewReader(file, stat.Size())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "unpacking model metadata"})
|
|
||||||
for _, f := range r.File {
|
|
||||||
if !filepath.IsLocal(f.Name) {
|
|
||||||
return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
n := filepath.Join(p, f.Name)
|
|
||||||
if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(mxyng): this should not write out all files to disk
|
|
||||||
outfile, err := os.Create(n)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer outfile.Close()
|
|
||||||
|
|
||||||
infile, err := f.Open()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer infile.Close()
|
|
||||||
|
|
||||||
if _, err = io.Copy(outfile, infile); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := outfile.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := infile.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
|
||||||
tempDir, err := os.MkdirTemp(filepath.Dir(file.Name()), "")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tempDir)
|
|
||||||
|
|
||||||
if err := extractFromZipFile(tempDir, file, fn); err != nil {
|
r, err := zip.NewReader(f, fi.Size())
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p, err := os.MkdirTemp(filepath.Dir(f.Name()), "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(p)
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "converting model"})
|
fn(api.ProgressResponse{Status: "converting model"})
|
||||||
|
|
||||||
// TODO(mxyng): this should write directly into a layer
|
// TODO(mxyng): this should write directly into a layer
|
||||||
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
|
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
|
||||||
temp, err := os.CreateTemp(tempDir, "fp16")
|
t, err := os.CreateTemp(p, "fp16")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer temp.Close()
|
defer t.Close()
|
||||||
defer os.Remove(temp.Name())
|
defer os.Remove(t.Name())
|
||||||
|
|
||||||
if err := convert.Convert(tempDir, temp); err != nil {
|
fn(api.ProgressResponse{Status: "converting model"})
|
||||||
|
if err := convert.Convert(convert.NewZipReader(r, p, 32<<20), t); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := temp.Seek(0, io.SeekStart); err != nil {
|
if _, err := t.Seek(0, io.SeekStart); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(temp, "application/vnd.ollama.image.model")
|
layer, err := NewLayer(t, "application/vnd.ollama.image.model")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,11 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/zip"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
@ -18,103 +13,6 @@ import (
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createZipFile(t *testing.T, name string) *os.File {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
zf := zip.NewWriter(f)
|
|
||||||
defer zf.Close()
|
|
||||||
|
|
||||||
zh, err := zf.CreateHeader(&zip.FileHeader{Name: name})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.Copy(zh, bytes.NewReader([]byte(""))); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExtractFromZipFile(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
expect []string
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "good",
|
|
||||||
expect: []string{"good"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
|
|
||||||
expect: []string{filepath.Join("to", "good")},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
|
|
||||||
expect: []string{"good"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
|
|
||||||
expect: []string{"good"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
|
|
||||||
err: zip.ErrInsecurePath,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
|
|
||||||
err: zip.ErrInsecurePath,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range cases {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
f := createZipFile(t, tt.name)
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var matches []string
|
|
||||||
if err := filepath.Walk(tempDir, func(p string, fi os.FileInfo, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !fi.IsDir() {
|
|
||||||
matches = append(matches, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var actual []string
|
|
||||||
for _, match := range matches {
|
|
||||||
rel, err := filepath.Rel(tempDir, match)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
actual = append(actual, rel)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !slices.Equal(actual, tt.expect) {
|
|
||||||
t.Fatalf("expected %d files, got %d", len(tt.expect), len(matches))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue