84 lines
1.5 KiB
Go
84 lines
1.5 KiB
Go
package convert
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"io/fs"
|
|
"strings"
|
|
)
|
|
|
|
type Tensor interface {
|
|
Name() string
|
|
Shape() []uint64
|
|
Kind() uint32
|
|
SetRepacker(repacker)
|
|
WriteTo(io.Writer) (int64, error)
|
|
}
|
|
|
|
type tensorBase struct {
|
|
name string
|
|
shape []uint64
|
|
repacker
|
|
}
|
|
|
|
func (t tensorBase) Name() string {
|
|
return t.name
|
|
}
|
|
|
|
func (t tensorBase) Shape() []uint64 {
|
|
return t.shape
|
|
}
|
|
|
|
// Tensor kind codes returned by Kind. The numeric values matter: they are
// handed out as raw uint32s, so do not reorder these entries.
const (
	// tensorKindF32 has the value 0.
	tensorKindF32 uint32 = iota
	// tensorKindF16 has the value 1.
	tensorKindF16
)
|
|
|
|
func (t tensorBase) Kind() uint32 {
|
|
if strings.HasSuffix(t.name, ".block_sparse_moe.gate.weight") {
|
|
return 0
|
|
} else if t.name == "embeddings.token_type_embeddings.weight" {
|
|
return 0
|
|
}
|
|
|
|
switch len(t.shape) {
|
|
case 0:
|
|
panic("invalid tensor shape")
|
|
case 1:
|
|
return tensorKindF32
|
|
default:
|
|
return tensorKindF16
|
|
}
|
|
}
|
|
|
|
func (t *tensorBase) SetRepacker(fn repacker) {
|
|
t.repacker = fn
|
|
}
|
|
|
|
// repacker is a callback that takes a string, a []float32 and a []uint64 and
// returns a []float32 or an error.
//
// NOTE(review): the arguments are presumably the tensor's name, its values
// and its shape, with the result being the rearranged values — confirm
// against the call sites, which are not in this part of the file.
type repacker func(string, []float32, []uint64) ([]float32, error)
|
|
|
|
func parseTensors(fsys fs.FS) ([]Tensor, error) {
|
|
patterns := []struct {
|
|
Pattern string
|
|
Func func(fs.FS, ...string) ([]Tensor, error)
|
|
}{
|
|
{"model-*-of-*.safetensors", parseSafetensors},
|
|
{"model.safetensors", parseSafetensors},
|
|
{"pytorch_model-*-of-*.bin", parseTorch},
|
|
{"pytorch_model.bin", parseTorch},
|
|
{"consolidated.*.pth", parseTorch},
|
|
}
|
|
|
|
for _, pattern := range patterns {
|
|
matches, err := fs.Glob(fsys, pattern.Pattern)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(matches) > 0 {
|
|
return pattern.Func(fsys, matches...)
|
|
}
|
|
}
|
|
|
|
return nil, errors.New("unknown tensor format")
|
|
}
|