throw an error when encountering unsupport tensor sizes (#6538)
This commit is contained in:
parent
93ea9240ae
commit
6c1c1ad6a9
2 changed files with 106 additions and 0 deletions
|
@ -140,6 +140,107 @@ func TestConvertFull(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertInvalidDatatype(t *testing.T) {
|
||||||
|
f, err := os.CreateTemp(t.TempDir(), "testmodel")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
generateSafetensorTestData(t, tempDir)
|
||||||
|
|
||||||
|
err = ConvertModel(os.DirFS(tempDir), f)
|
||||||
|
if err == nil || err.Error() != "unsupported safetensors model" {
|
||||||
|
t.Errorf("expected error but didn't get one")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSafetensorTestData(t *testing.T, tempDir string) {
|
||||||
|
type tensorData struct {
|
||||||
|
Offsets []int `json:"data_offsets"`
|
||||||
|
Type string `json:"dtype"`
|
||||||
|
Shape []int `json:"shape"`
|
||||||
|
}
|
||||||
|
offset := 4096 * 14336
|
||||||
|
|
||||||
|
td := map[string]*tensorData{}
|
||||||
|
td["model.layers.0.mlp.down_proj.weight"] = &tensorData{
|
||||||
|
Offsets: []int{0, offset},
|
||||||
|
Type: "I8",
|
||||||
|
Shape: []int{4096, 14336},
|
||||||
|
}
|
||||||
|
td["model.layers.0.mlp.down_proj.weight_format"] = &tensorData{
|
||||||
|
Offsets: []int{offset, offset},
|
||||||
|
Type: "U8",
|
||||||
|
Shape: []int{},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(td)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
l := int64(len(data))
|
||||||
|
err = binary.Write(&buf, binary.LittleEndian, l)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = buf.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fdata, err := os.Create(filepath.Join(tempDir, "model-00001-of-00001.safetensors"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer fdata.Close()
|
||||||
|
|
||||||
|
_, err = fdata.Write(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configData := `
|
||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"LlamaForCausalLM"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
f, err := os.Create(filepath.Join(tempDir, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = f.WriteString(configData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizerData := `
|
||||||
|
{
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
f, err = os.Create(filepath.Join(tempDir, "tokenizer.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
_, err = f.WriteString(tokenizerData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertAdapter(t *testing.T) {
|
func TestConvertAdapter(t *testing.T) {
|
||||||
type AdapterCase struct {
|
type AdapterCase struct {
|
||||||
Name string
|
Name string
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
@ -50,6 +51,10 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if value := headers[key]; value.Type != "" {
|
if value := headers[key]; value.Type != "" {
|
||||||
|
// bitsandbytes quantized models are unsupported
|
||||||
|
if len(value.Shape) == 0 {
|
||||||
|
return nil, errors.New("unsupported safetensors model")
|
||||||
|
}
|
||||||
ts = append(ts, safetensor{
|
ts = append(ts, safetensor{
|
||||||
fs: fsys,
|
fs: fsys,
|
||||||
path: p,
|
path: p,
|
||||||
|
|
Loading…
Reference in a new issue