2024-05-31 20:00:49 -07:00
|
|
|
package convert
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"io"
|
2024-06-29 16:53:59 -07:00
|
|
|
"io/fs"
|
2024-05-31 20:00:49 -07:00
|
|
|
"strings"
|
|
|
|
)
|
|
|
|
|
|
|
|
type Tensor interface {
|
|
|
|
Name() string
|
|
|
|
Shape() []uint64
|
|
|
|
Kind() uint32
|
|
|
|
SetRepacker(repacker)
|
|
|
|
WriteTo(io.Writer) (int64, error)
|
|
|
|
}
|
|
|
|
|
|
|
|
// tensorBase carries the state common to tensors: a name, a shape and an
// optional repacker. It presumably is embedded by the concrete Tensor
// implementations defined elsewhere in this package (not visible here).
type tensorBase struct {
	name  string
	shape []uint64
	repacker
}

// Name returns the tensor's name.
func (t tensorBase) Name() string {
	return t.name
}

// Shape returns the tensor's dimensions. The returned slice is the tensor's
// own backing slice, not a copy, so callers must not modify it.
func (t tensorBase) Shape() []uint64 {
	return t.shape
}

// Kind codes returned by tensorBase.Kind. The names suggest float32 and
// float16 element types respectively.
const (
	tensorKindF32 uint32 = iota
	tensorKindF16
)

// Kind returns tensorKindF32 or tensorKindF16 for this tensor.
//
// Tensors whose name ends in ".ffn_gate_inp.weight", or is exactly
// "token_types.weight", are always tensorKindF32 regardless of shape.
// Otherwise rank-1 tensors are tensorKindF32 and higher-rank tensors are
// tensorKindF16.
//
// Kind panics if the tensor has an empty (rank-0) shape and is not one of
// the always-F32 names above.
func (t tensorBase) Kind() uint32 {
	if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
		t.name == "token_types.weight" {
		// These tensors are always F32. Use the named constant rather than a
		// bare 0 so this branch stays in sync with the switch below.
		return tensorKindF32
	}

	switch len(t.shape) {
	case 0:
		panic("invalid tensor shape")
	case 1:
		return tensorKindF32
	default:
		return tensorKindF16
	}
}

// SetRepacker stores fn on the tensor, replacing any previously set
// repacker. It uses a pointer receiver so the change is visible to the
// caller's value.
func (t *tensorBase) SetRepacker(fn repacker) {
	t.repacker = fn
}

// repacker transforms a slice of float32 values and returns the result or
// an error. Its string and []uint64 arguments are presumably the tensor's
// name and shape (matching tensorBase's fields); confirm at call sites.
type repacker func(string, []float32, []uint64) ([]float32, error)
|
|
|
|
|
2024-06-28 13:27:05 -07:00
|
|
|
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
|
2024-07-31 15:39:11 -07:00
|
|
|
patterns := []struct {
|
|
|
|
Pattern string
|
2024-06-28 13:27:05 -07:00
|
|
|
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
|
2024-07-31 15:39:11 -07:00
|
|
|
}{
|
|
|
|
{"model-*-of-*.safetensors", parseSafetensors},
|
|
|
|
{"model.safetensors", parseSafetensors},
|
2024-08-23 11:29:56 -07:00
|
|
|
{"adapters.safetensors", parseSafetensors},
|
|
|
|
{"adapter_model.safetensors", parseSafetensors},
|
2024-07-31 15:39:11 -07:00
|
|
|
{"pytorch_model-*-of-*.bin", parseTorch},
|
|
|
|
{"pytorch_model.bin", parseTorch},
|
|
|
|
{"consolidated.*.pth", parseTorch},
|
2024-05-31 20:00:49 -07:00
|
|
|
}
|
|
|
|
|
2024-07-31 15:39:11 -07:00
|
|
|
for _, pattern := range patterns {
|
|
|
|
matches, err := fs.Glob(fsys, pattern.Pattern)
|
2024-05-31 20:00:49 -07:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(matches) > 0 {
|
2024-06-28 13:27:05 -07:00
|
|
|
return pattern.Func(fsys, replacer, matches...)
|
2024-05-31 20:00:49 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, errors.New("unknown tensor format")
|
|
|
|
}
|