Merge pull request #167 from jmorganca/decode-ggml

partial decode ggml bin for more info
This commit is contained in:
Michael Yang 2023-08-10 17:22:40 -07:00 committed by GitHub
commit 6a6828bddf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 336 additions and 69 deletions

180
llm/ggml.go Normal file
View file

@ -0,0 +1,180 @@
package llm
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
type ModelFamily string
const ModelFamilyLlama ModelFamily = "llama"
type ModelType uint32
const (
ModelType3B ModelType = 26
ModelType7B ModelType = 32
ModelType13B ModelType = 40
ModelType30B ModelType = 60
ModelType65B ModelType = 80
)
type FileType uint32
const (
FileTypeF32 FileType = iota
FileTypeF16
FileTypeQ4_0
FileTypeQ4_1
FileTypeQ4_1_F16
FileTypeQ8_0 = iota + 3
FileTypeQ5_0
FileTypeQ5_1
FileTypeQ2_K
FileTypeQ3_K
FileTypeQ4_K
FileTypeQ5_K
FileTypeQ6_K
FileTypeUnknown = -1
)
type GGML struct {
ModelFamily
ModelType
magic uint32
container
llamaHyperparameters
}
type container interface {
Name() string
Decode(io.Reader) error
}
type containerGGML struct {
}
func (c *containerGGML) Name() string {
return "ggml"
}
func (c *containerGGML) Decode(r io.Reader) error {
return nil
}
type containerGGMF struct {
version uint32
}
func (c *containerGGMF) Name() string {
return "ggmf"
}
func (c *containerGGMF) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerGGJT struct {
version uint32
}
func (c *containerGGJT) Name() string {
return "ggjt"
}
func (c *containerGGJT) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1, 2, 3:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerLORA struct {
version uint32
}
func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
const (
// / Magic constant for `ggml` files (unversioned).
FILE_MAGIC_GGML = 0x67676d6c
// / Magic constant for `ggml` files (versioned, ggmf).
FILE_MAGIC_GGMF = 0x67676d66
// / Magic constant for `ggml` files (versioned, ggjt).
FILE_MAGIC_GGJT = 0x67676a74
// / Magic constant for `ggla` files (LoRA adapter).
FILE_MAGIC_GGLA = 0x67676C61
)
func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
var ggml GGML
binary.Read(r, binary.LittleEndian, &ggml.magic)
switch ggml.magic {
case FILE_MAGIC_GGML:
ggml.container = &containerGGML{}
case FILE_MAGIC_GGMF:
ggml.container = &containerGGMF{}
case FILE_MAGIC_GGJT:
ggml.container = &containerGGJT{}
case FILE_MAGIC_GGLA:
ggml.container = &containerLORA{}
default:
return nil, errors.New("invalid file magic")
}
if err := ggml.Decode(r); err != nil {
return nil, err
}
// different model types may have different layouts for hyperparameters
switch hint {
case ModelFamilyLlama:
binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
// TODO: sanity check hyperparameters
default:
return nil, fmt.Errorf("unsupported model type: %s", hint)
}
// final model type
ggml.ModelFamily = hint
ggml.ModelType = ModelType(ggml.NumLayer)
return &ggml, nil
}

View file

@ -1,4 +1,4 @@
package llama package llm
/* /*
#cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS #cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
@ -105,7 +105,7 @@ import (
//go:embed ggml-metal.metal //go:embed ggml-metal.metal
var fs embed.FS var fs embed.FS
type LLM struct { type llama struct {
params *C.struct_llama_context_params params *C.struct_llama_context_params
model *C.struct_llama_model model *C.struct_llama_model
ctx *C.struct_llama_context ctx *C.struct_llama_context
@ -120,12 +120,28 @@ type LLM struct {
api.Options api.Options
} }
func New(model string, opts api.Options) (*LLM, error) { type llamaHyperparameters struct {
// NumVocab is the size of the model's vocabulary.
NumVocab uint32
// NumEmbd is the size of the model's embedding layer.
NumEmbd uint32
NumMult uint32
NumHead uint32
// NumLayer is the number of layers in the model.
NumLayer uint32
NumRot uint32
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
FileType
}
func newLlama(model string, opts api.Options) (*llama, error) {
if _, err := os.Stat(model); err != nil { if _, err := os.Stat(model); err != nil {
return nil, err return nil, err
} }
llm := LLM{Options: opts} llm := llama{Options: opts}
C.llama_backend_init(C.bool(llm.UseNUMA)) C.llama_backend_init(C.bool(llm.UseNUMA))
@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
return &llm, nil return &llm, nil
} }
func (llm *LLM) Close() { func (llm *llama) Close() {
llm.gc = true llm.gc = true
llm.mu.Lock() llm.mu.Lock()
@ -180,17 +196,16 @@ func (llm *LLM) Close() {
C.llama_print_timings(llm.ctx) C.llama_print_timings(llm.ctx)
} }
func (llm *llama) SetOptions(opts api.Options) {
llm.Options = opts
}
var errNeedMoreData = errors.New("need more data") var errNeedMoreData = errors.New("need more data")
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
C.llama_reset_timings(llm.ctx) C.llama_reset_timings(llm.ctx)
tokens := make([]C.llama_token, len(ctx)) llm.marshalPrompt(ctx, prompt)
for i := range tokens {
tokens[i] = C.llama_token(ctx[i])
}
llm.marshalPrompt(tokens, prompt)
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed)) C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return err return err
} }
b.WriteString(llm.Decode(token)) b.WriteString(llm.Decode(int(token)))
if err := llm.checkStopConditions(b); err != nil { if err := llm.checkStopConditions(b); err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return nil return nil
} }
func (llm *LLM) checkStopConditions(b bytes.Buffer) error { func (llm *llama) checkStopConditions(b bytes.Buffer) error {
for _, stopCondition := range llm.Stop { for _, stopCondition := range llm.Stop {
if stopCondition == strings.TrimSpace(b.String()) { if stopCondition == strings.TrimSpace(b.String()) {
return io.EOF return io.EOF
@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
return nil return nil
} }
func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token { func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
tokens := append(ctx, llm.Encode(prompt)...) tokens := append(ctx, llm.Encode(prompt)...)
if llm.NumKeep < 0 { if llm.NumKeep < 0 {
llm.NumKeep = len(tokens) llm.NumKeep = len(tokens)
} }
cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
// min(llm.NumCtx - 4, llm.NumKeep) // min(llm.NumCtx - 4, llm.NumKeep)
if llm.NumCtx-4 < llm.NumKeep { if llm.NumCtx-4 < llm.NumKeep {
llm.NumKeep = llm.NumCtx - 4 llm.NumKeep = llm.NumCtx - 4
@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
if len(tokens) >= llm.NumCtx { if len(tokens) >= llm.NumCtx {
// truncate input // truncate input
numLeft := (llm.NumCtx - llm.NumKeep) / 2 numLeft := (llm.NumCtx - llm.NumKeep) / 2
truncated := tokens[:llm.NumKeep] truncated := cTokens[:llm.NumKeep]
erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...) truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
copy(llm.last, tokens[len(tokens)-llm.NumCtx:]) copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
tokens = truncated cTokens = truncated
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated)) log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
} else { } else {
llm.last = make([]C.llama_token, llm.NumCtx-len(tokens)) llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
llm.last = append(llm.last, tokens...) llm.last = append(llm.last, cTokens...)
} }
var i int var i int
for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ { for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
// noop // noop
} }
llm.embd = tokens llm.embd = cTokens
if i == len(tokens) { if i == len(cTokens) {
// evaluate at least one token to generate logits // evaluate at least one token to generate logits
i-- i--
} }
@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
llm.cursor = i llm.cursor = i
log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:])) log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
return tokens return cTokens
} }
func (llm *LLM) Encode(prompt string) []C.llama_token { func (llm *llama) Encode(prompt string) []int {
cPrompt := C.CString(prompt) cPrompt := C.CString(prompt)
defer C.free(unsafe.Pointer(cPrompt)) defer C.free(unsafe.Pointer(cPrompt))
tokens := make([]C.llama_token, len(prompt)+1) cTokens := make([]C.llama_token, len(prompt)+1)
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 { if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
return tokens[:n] tokens := make([]int, n)
for i := range cTokens[:n] {
tokens[i] = int(cTokens[i])
}
return tokens
} }
return nil return nil
} }
func (llm *LLM) Decode(tokens ...C.llama_token) string { func (llm *llama) Decode(tokens ...int) string {
var sb strings.Builder var sb strings.Builder
for _, token := range tokens { for _, token := range tokens {
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token))) sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
} }
return sb.String() return sb.String()
} }
func (llm *LLM) next() (C.llama_token, error) { func (llm *llama) next() (C.llama_token, error) {
llm.mu.Lock() llm.mu.Lock()
defer llm.mu.Unlock() defer llm.mu.Unlock()
@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil return token, nil
} }
func (llm *LLM) Embedding(input string) ([]float64, error) { func (llm *llama) Embedding(input string) ([]float64, error) {
if !llm.EmbeddingOnly { if !llm.EmbeddingOnly {
return nil, errors.New("llama: embedding not enabled") return nil, errors.New("llama: embedding not enabled")
} }
@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
return nil, errors.New("llama: tokenize embedding") return nil, errors.New("llama: tokenize embedding")
} }
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread)) cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
if retval != 0 { if retval != 0 {
return nil, errors.New("llama: eval") return nil, errors.New("llama: eval")
} }

View file

@ -1,4 +1,4 @@
package llama package llm
import ( import (
"bytes" "bytes"

40
llm/llm.go Normal file
View file

@ -0,0 +1,40 @@
package llm
import (
"fmt"
"os"
"github.com/jmorganca/ollama/api"
)
type LLM interface {
Predict([]int, string, func(api.GenerateResponse)) error
Embedding(string) ([]float64, error)
Encode(string) []int
Decode(...int) string
SetOptions(api.Options)
Close()
}
func New(model string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
}
ggml, err := DecodeGGML(f, ModelFamilyLlama)
if err != nil {
return nil, err
}
switch ggml.ModelFamily {
case ModelFamilyLlama:
return newLlama(model, opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
}
}

View file

@ -1,4 +1,4 @@
package llama package llm
import ( import (
"fmt" "fmt"

View file

@ -19,7 +19,7 @@ import (
"strings" "strings"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector" "github.com/jmorganca/ollama/vector"
) )
@ -99,9 +99,14 @@ type LayerReader struct {
} }
type ConfigV2 struct { type ConfigV2 struct {
ModelFamily llm.ModelFamily `json:"model_family"`
ModelType llm.ModelType `json:"model_type"`
FileType llm.FileType `json:"file_type"`
RootFS RootFS `json:"rootfs"`
// required by spec
Architecture string `json:"architecture"` Architecture string `json:"architecture"`
OS string `json:"os"` OS string `json:"os"`
RootFS RootFS `json:"rootfs"`
} }
type RootFS struct { type RootFS struct {
@ -246,6 +251,11 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
config := ConfigV2{
Architecture: "amd64",
OS: "linux",
}
var layers []*LayerReader var layers []*LayerReader
params := make(map[string][]string) params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()} embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
@ -284,6 +294,18 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
} }
defer file.Close() defer file.Close()
ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
if err != nil {
return err
}
config.ModelFamily = ggml.ModelFamily
config.ModelType = ggml.ModelType
config.FileType = ggml.FileType
// reset the file
file.Seek(0, io.SeekStart)
l, err := CreateLayer(file) l, err := CreateLayer(file)
if err != nil { if err != nil {
return fmt.Errorf("failed to create layer: %v", err) return fmt.Errorf("failed to create layer: %v", err)
@ -292,6 +314,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layers = append(layers, l) layers = append(layers, l)
} }
} }
if mf != nil { if mf != nil {
log.Printf("manifest = %#v", mf) log.Printf("manifest = %#v", mf)
for _, l := range mf.Layers { for _, l := range mf.Layers {
@ -321,7 +344,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layers = append(layers, layer) layers = append(layers, layer)
case "template", "system", "prompt": case "template", "system", "prompt":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists // remove the layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layers = removeLayerFromLayers(layers, mediaType) layers = removeLayerFromLayers(layers, mediaType)
@ -383,7 +406,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
// Create a layer for the config object // Create a layer for the config object
fn(api.ProgressResponse{Status: "creating config layer"}) fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(digests) cfg, err := createConfigLayer(config, digests)
if err != nil { if err != nil {
return err return err
} }
@ -430,13 +453,13 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
} }
e.opts.EmbeddingOnly = true e.opts.EmbeddingOnly = true
llm, err := llama.New(e.model, e.opts) llmModel, err := llm.New(e.model, e.opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err) return nil, fmt.Errorf("load model to generate embeddings: %v", err)
} }
defer func() { defer func() {
if llm != nil { if llmModel != nil {
llm.Close() llmModel.Close()
} }
}() }()
@ -480,7 +503,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
Total: len(data) - 1, Total: len(data) - 1,
Completed: i, Completed: i,
}) })
embed, err := llm.Embedding(d) embed, err := llmModel.Embedding(d)
if err != nil { if err != nil {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err) log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue continue
@ -676,7 +699,7 @@ func getLayerDigests(layers []*LayerReader) ([]string, error) {
// CreateLayer creates a Layer object from a given file // CreateLayer creates a Layer object from a given file
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) { func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
digest, size := GetSHA256Digest(f) digest, size := GetSHA256Digest(f)
f.Seek(0, 0) f.Seek(0, io.SeekStart)
layer := &LayerReader{ layer := &LayerReader{
Layer: Layer{ Layer: Layer{
@ -768,10 +791,6 @@ func DeleteModel(name string) error {
return err return err
} }
if err != nil {
return err
}
// only delete the files which are still in the deleteMap // only delete the files which are still in the deleteMap
for k, v := range deleteMap { for k, v := range deleteMap {
if v { if v {
@ -970,15 +989,10 @@ func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, err
return m, err return m, err
} }
func createConfigLayer(layers []string) (*LayerReader, error) { func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
// TODO change architecture and OS config.RootFS = RootFS{
config := ConfigV2{ Type: "layers",
Architecture: "arm64", DiffIDs: layers,
OS: "linux",
RootFS: RootFS{
Type: "layers",
DiffIDs: layers,
},
} }
configJSON, err := json.Marshal(config) configJSON, err := json.Marshal(config)

View file

@ -21,14 +21,14 @@ import (
"gonum.org/v1/gonum/mat" "gonum.org/v1/gonum/mat"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/vector" "github.com/jmorganca/ollama/vector"
) )
var loaded struct { var loaded struct {
mu sync.Mutex mu sync.Mutex
llm *llama.LLM llm llm.LLM
Embeddings []vector.Embedding Embeddings []vector.Embedding
expireAt time.Time expireAt time.Time
@ -63,11 +63,16 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
loaded.Embeddings = model.Embeddings loaded.Embeddings = model.Embeddings
} }
llm, err := llama.New(model.ModelPath, opts) llmModel, err := llm.New(model.ModelPath, opts)
if err != nil { if err != nil {
return err return err
} }
// set cache values before modifying opts
loaded.llm = llmModel
loaded.digest = model.Digest
loaded.options = opts
if opts.NumKeep < 0 { if opts.NumKeep < 0 {
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "") promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
if err != nil { if err != nil {
@ -79,15 +84,13 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
return err return err
} }
tokensWithSystem := llm.Encode(promptWithSystem) tokensWithSystem := llmModel.Encode(promptWithSystem)
tokensNoSystem := llm.Encode(promptNoSystem) tokensNoSystem := llmModel.Encode(promptNoSystem)
llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1 opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
llmModel.SetOptions(opts)
} }
loaded.llm = llm
loaded.digest = model.Digest
loaded.options = opts
} }
loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireAt = time.Now().Add(sessionDuration)