instead of static number of parameters for each model family, get the real number from the tensors (#1022)

* parse tensor info

* refactor decoder

* return actual parameter count

* explicit rounding

* s/Human/HumanNumber/
This commit is contained in:
Michael Yang 2023-11-08 19:55:46 -06:00 committed by GitHub
parent a49d6acc1e
commit c5e1bbabda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 18 deletions

25
format/format.go Normal file
View file

@ -0,0 +1,25 @@
package format
import (
"fmt"
"math"
)
const (
Thousand = 1000
Million = Thousand * 1000
Billion = Million * 1000
)
func HumanNumber(b uint64) string {
switch {
case b > Billion:
return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
case b > Million:
return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
case b > Thousand:
return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
default:
return fmt.Sprintf("%d", b)
}
}

View file

@ -5,6 +5,8 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"github.com/jmorganca/ollama/format"
) )
type containerGGUF struct { type containerGGUF struct {
@ -21,6 +23,8 @@ type containerGGUF struct {
NumTensor uint64 NumTensor uint64
NumKV uint64 NumKV uint64
} }
parameters uint64
} }
func (c *containerGGUF) Name() string { func (c *containerGGUF) Name() string {
@ -75,6 +79,14 @@ func newGGUFModel(container *containerGGUF) *ggufModel {
} }
} }
func (llm *ggufModel) NumTensor() uint64 {
if llm.Version == 1 {
return uint64(llm.V1.NumTensor)
}
return llm.V2.NumTensor
}
func (llm *ggufModel) NumKV() uint64 { func (llm *ggufModel) NumKV() uint64 {
if llm.Version == 1 { if llm.Version == 1 {
return uint64(llm.V1.NumKV) return uint64(llm.V1.NumKV)
@ -93,6 +105,10 @@ func (llm *ggufModel) ModelFamily() string {
} }
func (llm *ggufModel) ModelType() string { func (llm *ggufModel) ModelType() string {
if llm.parameters > 0 {
return format.HumanNumber(llm.parameters)
}
switch llm.ModelFamily() { switch llm.ModelFamily() {
case "llama": case "llama":
if blocks, ok := llm.kv["llama.block_count"].(uint32); ok { if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
@ -127,13 +143,9 @@ func (llm *ggufModel) FileType() string {
} }
func (llm *ggufModel) Decode(r io.Reader) error { func (llm *ggufModel) Decode(r io.Reader) error {
read := llm.readString // decode key-values
if llm.Version == 1 {
read = llm.readStringV1
}
for i := 0; uint64(i) < llm.NumKV(); i++ { for i := 0; uint64(i) < llm.NumKV(); i++ {
k, err := read(r) k, err := llm.readString(r)
if err != nil { if err != nil {
return err return err
} }
@ -165,24 +177,14 @@ func (llm *ggufModel) Decode(r io.Reader) error {
case ggufTypeBool: case ggufTypeBool:
v = llm.readBool(r) v = llm.readBool(r)
case ggufTypeString: case ggufTypeString:
fn := llm.readString s, err := llm.readString(r)
if llm.Version == 1 {
fn = llm.readStringV1
}
s, err := fn(r)
if err != nil { if err != nil {
return err return err
} }
v = s v = s
case ggufTypeArray: case ggufTypeArray:
fn := llm.readArray a, err := llm.readArray(r)
if llm.Version == 1 {
fn = llm.readArrayV1
}
a, err := fn(r)
if err != nil { if err != nil {
return err return err
} }
@ -195,6 +197,25 @@ func (llm *ggufModel) Decode(r io.Reader) error {
llm.kv[k] = v llm.kv[k] = v
} }
// decode tensors
for i := 0; uint64(i) < llm.NumTensor(); i++ {
if _, err := llm.readString(r); err != nil {
return err
}
dimensions := llm.readU32(r)
var elements uint64 = 1
for i := 0; uint32(i) < dimensions; i++ {
elements *= llm.readU64(r)
}
llm.readU32(r) // type
llm.readU64(r) // offset
llm.parameters += elements
}
return nil return nil
} }
@ -290,6 +311,10 @@ func (llm ggufModel) readStringV1(r io.Reader) (string, error) {
} }
func (llm ggufModel) readString(r io.Reader) (string, error) { func (llm ggufModel) readString(r io.Reader) (string, error) {
if llm.Version == 1 {
return llm.readStringV1(r)
}
var nameLength uint64 var nameLength uint64
binary.Read(r, llm.bo, &nameLength) binary.Read(r, llm.bo, &nameLength)
@ -339,6 +364,10 @@ func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
} }
func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) { func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
if llm.Version == 1 {
return llm.readArrayV1(r)
}
atype := llm.readU32(r) atype := llm.readU32(r)
n := llm.readU64(r) n := llm.readU64(r)