ollama/convert/llama.go

159 lines
4 KiB
Go
Raw Normal View History

package convert
import (
2024-05-17 19:11:49 +00:00
"cmp"
2024-05-15 18:53:14 +00:00
"errors"
"fmt"
"io"
2024-04-28 17:36:38 +00:00
"os"
"path/filepath"
"regexp"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
// LlamaModel converts a Llama-architecture checkpoint to GGUF. It has no
// fields of its own: the path, format, hyperparameters, vocabulary and tensor
// list it works on all come from the embedded ModelData.
type LlamaModel struct{ ModelData }
func (m *LlamaModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
return err
}
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return err
}
for _, l := range t {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
2024-04-28 17:36:38 +00:00
switch m.Format.(type) {
case *TorchFormat:
2024-04-25 01:32:01 +00:00
wt := l.WriterTo.(torchWriterTo)
2024-05-17 19:11:49 +00:00
wt.repacker = m.Repack
2024-04-25 01:32:01 +00:00
l.WriterTo = wt
2024-04-28 17:36:38 +00:00
case *SafetensorFormat:
2024-04-25 01:32:01 +00:00
wt := l.WriterTo.(safetensorWriterTo)
2024-05-17 19:11:49 +00:00
wt.repacker = m.Repack
2024-04-25 01:32:01 +00:00
l.WriterTo = wt
}
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
2024-05-15 21:55:57 +00:00
func (m *LlamaModel) LoadVocab() (err error) {
pre, ts, merges, err := parseTokens(filepath.Join(m.Path, "tokenizer.json"))
2024-05-15 18:53:14 +00:00
if errors.Is(err, os.ErrNotExist) {
2024-05-15 21:55:57 +00:00
return nil
2024-05-15 18:53:14 +00:00
} else if err != nil {
return err
2024-04-28 17:36:38 +00:00
}
2024-05-15 18:53:14 +00:00
2024-05-15 21:55:57 +00:00
m.Vocab = &Vocab{}
for _, t := range ts {
m.Vocab.Tokens = append(m.Vocab.Tokens, t.Content)
m.Vocab.Types = append(m.Vocab.Types, t.Type())
}
2024-04-28 17:36:38 +00:00
2024-05-15 21:55:57 +00:00
m.Vocab.Merges = merges
m.Params.PreTokenizer = pre
return nil
}
func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
"llama.context_length": uint32(m.Params.ContextSize),
"llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.block_count": uint32(m.Params.HiddenLayers),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
2024-04-28 17:36:38 +00:00
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
2024-05-15 21:55:57 +00:00
"general.file_type": uint32(1),
2024-05-08 23:07:46 +00:00
"tokenizer.ggml.model": "gpt2",
2024-05-15 18:53:14 +00:00
"tokenizer.ggml.pre": m.Params.PreTokenizer,
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0),
2024-04-28 17:36:38 +00:00
}
if len(m.Vocab.Merges) > 0 {
kv["tokenizer.ggml.merges"] = m.Vocab.Merges
} else {
kv["tokenizer.ggml.scores"] = m.Vocab.Scores
}
2024-05-06 21:00:50 +00:00
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}
2024-05-17 19:11:49 +00:00
// Repack permutes the data of a single attention q or k weight tensor by
// handing it, together with this model's hyperparameters, to llamaRepack.
// Its signature lets GetTensors install it directly as a writer's repacker.
func (m *LlamaModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
	params := m.Params
	return llamaRepack(name, params, data, shape)
}
// llamaRepack reorders the rows of an attention q or k weight.
//
// The weight must be 2-D, [rows, cols], after zero entries in shape are
// dropped. Its rows are viewed as [heads, 2, rows/heads/2]. Axes 1 and 2 are
// swapped, so within each head the two interleaved halves of rows become
// pairs. The result is returned flattened row by row. This is presumably the
// Hugging Face to GGUF rotary-embedding permutation; TODO confirm.
//
// name selects the head count. A name ending in "attn_q.weight" uses
// params.AttentionHeads. A name ending in "attn_k.weight" uses
// params.KeyValHeads, or params.AttentionHeads when that is zero. Any other
// name is an error.
//
// It returns an error, rather than panicking, when the shape is not 2-D, the
// head count is not positive, rows is not a multiple of 2*heads, or len(data)
// does not equal rows*cols. Errors from the tensor library are wrapped.
func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([]float32, error) {
	var heads int
	switch {
	case strings.HasSuffix(name, "attn_q.weight"):
		heads = params.AttentionHeads
	case strings.HasSuffix(name, "attn_k.weight"):
		// cmp.Or returns its first non-zero argument.
		heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
	default:
		return nil, fmt.Errorf("unknown tensor name: %s", name)
	}

	// Zero entries in shape are skipped (presumably padding; TODO confirm).
	dims := make([]int, 0, len(shape))
	for _, dim := range shape {
		if dim != 0 {
			dims = append(dims, int(dim))
		}
	}

	// The permutation below names exactly four axes, [heads, 2, rows/heads/2, cols],
	// so only a 2-D weight can be repacked.
	if len(dims) != 2 {
		return nil, fmt.Errorf("repacking %s: expected a 2-D weight, got shape %v", name, shape)
	}
	if heads <= 0 {
		return nil, fmt.Errorf("repacking %s: invalid head count %d", name, heads)
	}
	rows, cols := dims[0], dims[1]
	if rows%(2*heads) != 0 {
		return nil, fmt.Errorf("repacking %s: %d rows not divisible by 2*%d heads", name, rows, heads)
	}
	if rows*cols != len(data) {
		return nil, fmt.Errorf("repacking %s: shape %v needs %d values, got %d", name, shape, rows*cols, len(data))
	}

	n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
	if err := n.Reshape(heads, 2, rows/heads/2, cols); err != nil {
		return nil, fmt.Errorf("repacking %s: %w", name, err)
	}
	if err := n.T(0, 2, 1, 3); err != nil {
		return nil, fmt.Errorf("repacking %s: %w", name, err)
	}
	if err := n.Reshape(dims...); err != nil {
		return nil, fmt.Errorf("repacking %s: %w", name, err)
	}
	// NOTE(review): in the tensor package, T appears to only record the axis
	// permutation and Transpose to apply it to the backing data; confirm
	// against its docs before reordering these calls.
	if err := n.Transpose(); err != nil {
		return nil, fmt.Errorf("repacking %s: %w", name, err)
	}

	parts, err := native.SelectF32(n, 1)
	if err != nil {
		return nil, fmt.Errorf("repacking %s: %w", name, err)
	}

	out := make([]float32, 0, len(data))
	for _, p := range parts {
		out = append(out, p...)
	}
	return out, nil
}