Merge pull request #167 from jmorganca/decode-ggml

partial decode ggml bin for more info
This commit is contained in:
Michael Yang 2023-08-10 17:22:40 -07:00 committed by GitHub
commit 6a6828bddf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 336 additions and 69 deletions

180
llm/ggml.go Normal file
View file

@ -0,0 +1,180 @@
package llm
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
type ModelFamily string
const ModelFamilyLlama ModelFamily = "llama"
type ModelType uint32
const (
ModelType3B ModelType = 26
ModelType7B ModelType = 32
ModelType13B ModelType = 40
ModelType30B ModelType = 60
ModelType65B ModelType = 80
)
type FileType uint32
const (
FileTypeF32 FileType = iota
FileTypeF16
FileTypeQ4_0
FileTypeQ4_1
FileTypeQ4_1_F16
FileTypeQ8_0 = iota + 3
FileTypeQ5_0
FileTypeQ5_1
FileTypeQ2_K
FileTypeQ3_K
FileTypeQ4_K
FileTypeQ5_K
FileTypeQ6_K
FileTypeUnknown = -1
)
type GGML struct {
ModelFamily
ModelType
magic uint32
container
llamaHyperparameters
}
type container interface {
Name() string
Decode(io.Reader) error
}
type containerGGML struct {
}
func (c *containerGGML) Name() string {
return "ggml"
}
func (c *containerGGML) Decode(r io.Reader) error {
return nil
}
type containerGGMF struct {
version uint32
}
func (c *containerGGMF) Name() string {
return "ggmf"
}
func (c *containerGGMF) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerGGJT struct {
version uint32
}
func (c *containerGGJT) Name() string {
return "ggjt"
}
func (c *containerGGJT) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1, 2, 3:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
type containerLORA struct {
version uint32
}
func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(r io.Reader) error {
var version uint32
binary.Read(r, binary.LittleEndian, &version)
switch version {
case 1:
default:
return errors.New("invalid version")
}
c.version = version
return nil
}
const (
// / Magic constant for `ggml` files (unversioned).
FILE_MAGIC_GGML = 0x67676d6c
// / Magic constant for `ggml` files (versioned, ggmf).
FILE_MAGIC_GGMF = 0x67676d66
// / Magic constant for `ggml` files (versioned, ggjt).
FILE_MAGIC_GGJT = 0x67676a74
// / Magic constant for `ggla` files (LoRA adapter).
FILE_MAGIC_GGLA = 0x67676C61
)
func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
var ggml GGML
binary.Read(r, binary.LittleEndian, &ggml.magic)
switch ggml.magic {
case FILE_MAGIC_GGML:
ggml.container = &containerGGML{}
case FILE_MAGIC_GGMF:
ggml.container = &containerGGMF{}
case FILE_MAGIC_GGJT:
ggml.container = &containerGGJT{}
case FILE_MAGIC_GGLA:
ggml.container = &containerLORA{}
default:
return nil, errors.New("invalid file magic")
}
if err := ggml.Decode(r); err != nil {
return nil, err
}
// different model types may have different layouts for hyperparameters
switch hint {
case ModelFamilyLlama:
binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
// TODO: sanity check hyperparameters
default:
return nil, fmt.Errorf("unsupported model type: %s", hint)
}
// final model type
ggml.ModelFamily = hint
ggml.ModelType = ModelType(ggml.NumLayer)
return &ggml, nil
}

View file

@ -1,4 +1,4 @@
package llama
package llm
/*
#cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
@ -105,7 +105,7 @@ import (
//go:embed ggml-metal.metal
var fs embed.FS
type LLM struct {
type llama struct {
params *C.struct_llama_context_params
model *C.struct_llama_model
ctx *C.struct_llama_context
@ -120,12 +120,28 @@ type LLM struct {
api.Options
}
func New(model string, opts api.Options) (*LLM, error) {
type llamaHyperparameters struct {
// NumVocab is the size of the model's vocabulary.
NumVocab uint32
// NumEmbd is the size of the model's embedding layer.
NumEmbd uint32
NumMult uint32
NumHead uint32
// NumLayer is the number of layers in the model.
NumLayer uint32
NumRot uint32
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
FileType
}
func newLlama(model string, opts api.Options) (*llama, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
llm := LLM{Options: opts}
llm := llama{Options: opts}
C.llama_backend_init(C.bool(llm.UseNUMA))
@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
return &llm, nil
}
func (llm *LLM) Close() {
func (llm *llama) Close() {
llm.gc = true
llm.mu.Lock()
@ -180,17 +196,16 @@ func (llm *LLM) Close() {
C.llama_print_timings(llm.ctx)
}
var errNeedMoreData = errors.New("need more data")
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
C.llama_reset_timings(llm.ctx)
tokens := make([]C.llama_token, len(ctx))
for i := range tokens {
tokens[i] = C.llama_token(ctx[i])
func (llm *llama) SetOptions(opts api.Options) {
llm.Options = opts
}
llm.marshalPrompt(tokens, prompt)
var errNeedMoreData = errors.New("need more data")
func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
C.llama_reset_timings(llm.ctx)
llm.marshalPrompt(ctx, prompt)
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return err
}
b.WriteString(llm.Decode(token))
b.WriteString(llm.Decode(int(token)))
if err := llm.checkStopConditions(b); err != nil {
if errors.Is(err, io.EOF) {
@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
return nil
}
func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
func (llm *llama) checkStopConditions(b bytes.Buffer) error {
for _, stopCondition := range llm.Stop {
if stopCondition == strings.TrimSpace(b.String()) {
return io.EOF
@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
return nil
}
func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
tokens := append(ctx, llm.Encode(prompt)...)
if llm.NumKeep < 0 {
llm.NumKeep = len(tokens)
}
cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
// min(llm.NumCtx - 4, llm.NumKeep)
if llm.NumCtx-4 < llm.NumKeep {
llm.NumKeep = llm.NumCtx - 4
@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
if len(tokens) >= llm.NumCtx {
// truncate input
numLeft := (llm.NumCtx - llm.NumKeep) / 2
truncated := tokens[:llm.NumKeep]
erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
truncated := cTokens[:llm.NumKeep]
erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
tokens = truncated
cTokens = truncated
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
} else {
llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
llm.last = append(llm.last, tokens...)
llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
llm.last = append(llm.last, cTokens...)
}
var i int
for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
// noop
}
llm.embd = tokens
if i == len(tokens) {
llm.embd = cTokens
if i == len(cTokens) {
// evaluate at least one token to generate logits
i--
}
@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
llm.cursor = i
log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
return tokens
return cTokens
}
func (llm *LLM) Encode(prompt string) []C.llama_token {
func (llm *llama) Encode(prompt string) []int {
cPrompt := C.CString(prompt)
defer C.free(unsafe.Pointer(cPrompt))
tokens := make([]C.llama_token, len(prompt)+1)
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 {
return tokens[:n]
cTokens := make([]C.llama_token, len(prompt)+1)
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
tokens := make([]int, n)
for i := range cTokens[:n] {
tokens[i] = int(cTokens[i])
}
return tokens
}
return nil
}
func (llm *LLM) Decode(tokens ...C.llama_token) string {
func (llm *llama) Decode(tokens ...int) string {
var sb strings.Builder
for _, token := range tokens {
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
}
return sb.String()
}
func (llm *LLM) next() (C.llama_token, error) {
func (llm *llama) next() (C.llama_token, error) {
llm.mu.Lock()
defer llm.mu.Unlock()
@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
return token, nil
}
func (llm *LLM) Embedding(input string) ([]float64, error) {
func (llm *llama) Embedding(input string) ([]float64, error) {
if !llm.EmbeddingOnly {
return nil, errors.New("llama: embedding not enabled")
}
@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
return nil, errors.New("llama: tokenize embedding")
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
cTokens := make([]C.llama_token, len(tokens))
for i := range tokens {
cTokens[i] = C.llama_token(tokens[i])
}
retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
if retval != 0 {
return nil, errors.New("llama: eval")
}

View file

@ -1,4 +1,4 @@
package llama
package llm
import (
"bytes"

40
llm/llm.go Normal file
View file

@ -0,0 +1,40 @@
package llm
import (
"fmt"
"os"
"github.com/jmorganca/ollama/api"
)
type LLM interface {
Predict([]int, string, func(api.GenerateResponse)) error
Embedding(string) ([]float64, error)
Encode(string) []int
Decode(...int) string
SetOptions(api.Options)
Close()
}
func New(model string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
}
ggml, err := DecodeGGML(f, ModelFamilyLlama)
if err != nil {
return nil, err
}
switch ggml.ModelFamily {
case ModelFamilyLlama:
return newLlama(model, opts)
default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
}
}

View file

@ -1,4 +1,4 @@
package llama
package llm
import (
"fmt"

View file

@ -19,7 +19,7 @@ import (
"strings"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/vector"
)
@ -99,9 +99,14 @@ type LayerReader struct {
}
type ConfigV2 struct {
ModelFamily llm.ModelFamily `json:"model_family"`
ModelType llm.ModelType `json:"model_type"`
FileType llm.FileType `json:"file_type"`
RootFS RootFS `json:"rootfs"`
// required by spec
Architecture string `json:"architecture"`
OS string `json:"os"`
RootFS RootFS `json:"rootfs"`
}
type RootFS struct {
@ -246,6 +251,11 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err
}
config := ConfigV2{
Architecture: "amd64",
OS: "linux",
}
var layers []*LayerReader
params := make(map[string][]string)
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
@ -284,6 +294,18 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
}
defer file.Close()
ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
if err != nil {
return err
}
config.ModelFamily = ggml.ModelFamily
config.ModelType = ggml.ModelType
config.FileType = ggml.FileType
// reset the file
file.Seek(0, io.SeekStart)
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
@ -292,6 +314,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layers = append(layers, l)
}
}
if mf != nil {
log.Printf("manifest = %#v", mf)
for _, l := range mf.Layers {
@ -321,7 +344,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
layers = append(layers, layer)
case "template", "system", "prompt":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
// remove the prompt layer if one exists
// remove the layer if one exists
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
layers = removeLayerFromLayers(layers, mediaType)
@ -383,7 +406,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
// Create a layer for the config object
fn(api.ProgressResponse{Status: "creating config layer"})
cfg, err := createConfigLayer(digests)
cfg, err := createConfigLayer(config, digests)
if err != nil {
return err
}
@ -430,13 +453,13 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
}
e.opts.EmbeddingOnly = true
llm, err := llama.New(e.model, e.opts)
llmModel, err := llm.New(e.model, e.opts)
if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
}
defer func() {
if llm != nil {
llm.Close()
if llmModel != nil {
llmModel.Close()
}
}()
@ -480,7 +503,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
Total: len(data) - 1,
Completed: i,
})
embed, err := llm.Embedding(d)
embed, err := llmModel.Embedding(d)
if err != nil {
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
continue
@ -676,7 +699,7 @@ func getLayerDigests(layers []*LayerReader) ([]string, error) {
// CreateLayer creates a Layer object from a given file
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
digest, size := GetSHA256Digest(f)
f.Seek(0, 0)
f.Seek(0, io.SeekStart)
layer := &LayerReader{
Layer: Layer{
@ -768,10 +791,6 @@ func DeleteModel(name string) error {
return err
}
if err != nil {
return err
}
// only delete the files which are still in the deleteMap
for k, v := range deleteMap {
if v {
@ -970,15 +989,10 @@ func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, err
return m, err
}
func createConfigLayer(layers []string) (*LayerReader, error) {
// TODO change architecture and OS
config := ConfigV2{
Architecture: "arm64",
OS: "linux",
RootFS: RootFS{
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
config.RootFS = RootFS{
Type: "layers",
DiffIDs: layers,
},
}
configJSON, err := json.Marshal(config)

View file

@ -21,14 +21,14 @@ import (
"gonum.org/v1/gonum/mat"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/vector"
)
var loaded struct {
mu sync.Mutex
llm *llama.LLM
llm llm.LLM
Embeddings []vector.Embedding
expireAt time.Time
@ -63,11 +63,16 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
loaded.Embeddings = model.Embeddings
}
llm, err := llama.New(model.ModelPath, opts)
llmModel, err := llm.New(model.ModelPath, opts)
if err != nil {
return err
}
// set cache values before modifying opts
loaded.llm = llmModel
loaded.digest = model.Digest
loaded.options = opts
if opts.NumKeep < 0 {
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
if err != nil {
@ -79,15 +84,13 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
return err
}
tokensWithSystem := llm.Encode(promptWithSystem)
tokensNoSystem := llm.Encode(promptNoSystem)
tokensWithSystem := llmModel.Encode(promptWithSystem)
tokensNoSystem := llmModel.Encode(promptNoSystem)
llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
llmModel.SetOptions(opts)
}
loaded.llm = llm
loaded.digest = model.Digest
loaded.options = opts
}
loaded.expireAt = time.Now().Add(sessionDuration)