ollama/convert/gemma.go

103 lines
2.8 KiB
Go
Raw Normal View History

2024-04-01 16:14:53 -07:00
package convert
import (
"fmt"
"io"
"log/slog"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/llm"
)
// GemmaModel converts a Gemma checkpoint into a GGUF file. It adds no fields
// of its own: the path, params, format, tensors, vocabulary and name used by
// its methods are all promoted from the embedded ModelData.
type GemmaModel struct {
	ModelData
}
func addOnes(data []float32, vectorSize int) ([]float32, error) {
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
ones := tensor.Ones(tensor.Float32, vectorSize)
2024-05-17 12:11:49 -07:00
n, err := n.Add(ones)
2024-04-01 16:14:53 -07:00
if err != nil {
2024-05-17 12:11:49 -07:00
return nil, err
2024-04-01 16:14:53 -07:00
}
2024-05-17 12:11:49 -07:00
ts, err := native.SelectF32(n, 0)
2024-04-01 16:14:53 -07:00
if err != nil {
2024-05-17 12:11:49 -07:00
return nil, err
2024-04-01 16:14:53 -07:00
}
2024-05-17 12:11:49 -07:00
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
2024-04-01 16:14:53 -07:00
}
2024-05-17 12:11:49 -07:00
return f32s, nil
2024-04-01 16:14:53 -07:00
}
func (m *GemmaModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
2024-04-01 16:14:53 -07:00
if err != nil {
return err
}
slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))
2024-04-01 16:14:53 -07:00
for _, l := range t {
if strings.HasSuffix(l.Name, "norm.weight") {
wt := l.WriterTo.(safetensorWriterTo)
2024-05-17 12:11:49 -07:00
wt.repacker = m.Repack
2024-04-01 16:14:53 -07:00
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
func (m *GemmaModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params)
2024-04-01 16:14:53 -07:00
if err != nil {
return err
}
m.Vocab = v
return nil
}
2024-05-17 12:11:49 -07:00
func (m *GemmaModel) Repack(_ string, data []float32, shape []uint64) ([]float32, error) {
return addOnes(data, int(shape[0]))
}
func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
2024-04-01 16:14:53 -07:00
kv := llm.KV{
"general.architecture": "gemma",
"general.name": m.Name,
"gemma.context_length": uint32(m.Params.ContextSize),
"gemma.embedding_length": uint32(m.Params.HiddenSize),
"gemma.block_count": uint32(m.Params.HiddenLayers),
"gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
"gemma.attention.head_count": uint32(m.Params.AttentionHeads),
"gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"gemma.attention.key_length": uint32(m.Params.HeadDimension),
"gemma.attention.value_length": uint32(m.Params.HeadDimension),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
"tokenizer.ggml.unknown_token_id": uint32(3),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
}
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
2024-04-01 16:14:53 -07:00
}