fix parsing big endian gguf

This commit is contained in:
Michael Yang 2024-06-08 12:32:02 -07:00
parent cddc63381c
commit 620d5c569e
2 changed files with 19 additions and 9 deletions

View file

@ -231,8 +231,7 @@ const (
// Magic constant for `ggla` files (LoRA adapter). // Magic constant for `ggla` files (LoRA adapter).
FILE_MAGIC_GGLA = 0x67676C61 FILE_MAGIC_GGLA = 0x67676C61
// Magic constant for `gguf` files (versioned, gguf) // Magic constant for `gguf` files (versioned, gguf)
FILE_MAGIC_GGUF_LE = 0x46554747 FILE_MAGIC_GGUF = 0x46554747
FILE_MAGIC_GGUF_BE = 0x47475546
) )
var ErrUnsupportedFormat = errors.New("unsupported model format") var ErrUnsupportedFormat = errors.New("unsupported model format")
@ -247,7 +246,7 @@ func DetectGGMLType(b []byte) string {
return "ggjt" return "ggjt"
case FILE_MAGIC_GGLA: case FILE_MAGIC_GGLA:
return "ggla" return "ggla"
case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE: case FILE_MAGIC_GGUF:
return "gguf" return "gguf"
default: default:
return "" return ""
@ -255,21 +254,19 @@ func DetectGGMLType(b []byte) string {
} }
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
var magic uint32 var magic [4]byte
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, 0, err return nil, 0, err
} }
var c container var c container
switch magic { switch binary.LittleEndian.Uint32(magic[:]) {
case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT: case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
return nil, 0, ErrUnsupportedFormat return nil, 0, ErrUnsupportedFormat
case FILE_MAGIC_GGLA: case FILE_MAGIC_GGLA:
c = &containerGGLA{} c = &containerGGLA{}
case FILE_MAGIC_GGUF_LE: case FILE_MAGIC_GGUF:
c = &containerGGUF{ByteOrder: binary.LittleEndian} c = &containerGGUF{ByteOrder: binary.LittleEndian}
case FILE_MAGIC_GGUF_BE:
c = &containerGGUF{ByteOrder: binary.BigEndian}
default: default:
return nil, 0, errors.New("invalid file magic") return nil, 0, errors.New("invalid file magic")
} }

View file

@ -36,10 +36,23 @@ func (c *containerGGUF) Name() string {
} }
func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) { func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil { var version [4]byte
if err := binary.Read(rs, c.ByteOrder, &version); err != nil {
return nil, err return nil, err
} }
// if the lower 16 bits are 0, the byte order is probably wrong
if c.ByteOrder.Uint32(version[:])&1<<4 == 0 {
switch c.ByteOrder {
case binary.LittleEndian:
c.ByteOrder = binary.BigEndian
case binary.BigEndian:
c.ByteOrder = binary.LittleEndian
}
}
c.Version = c.ByteOrder.Uint32(version[:])
var err error var err error
switch c.Version { switch c.Version {
case 1: case 1: