137 lines
3.3 KiB
Go
137 lines
3.3 KiB
Go
|
package convert
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log/slog"
|
||
|
"os"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/d4l3k/go-bfloat16"
|
||
|
"github.com/pdevine/tensor"
|
||
|
"github.com/pdevine/tensor/native"
|
||
|
|
||
|
"github.com/ollama/ollama/llm"
|
||
|
)
|
||
|
|
||
|
// GemmaModel is the converter for the "gemma" GGUF architecture. All of its
// state (Path, Params, Name, Tensors and Vocab, as used by the methods in this
// file) lives in the embedded ModelData.
type GemmaModel struct {
	ModelData
}
|
||
|
|
||
|
func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
|
||
|
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
|
||
|
|
||
|
data := make([]byte, r.end-r.start)
|
||
|
if err := binary.Read(f, r.bo, data); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
tDataF32 := bfloat16.DecodeFloat32(data)
|
||
|
|
||
|
var err error
|
||
|
tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func addOnes(data []float32, vectorSize int) ([]float32, error) {
|
||
|
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
|
||
|
ones := tensor.Ones(tensor.Float32, vectorSize)
|
||
|
|
||
|
var err error
|
||
|
n, err = n.Add(ones)
|
||
|
if err != nil {
|
||
|
return []float32{}, err
|
||
|
}
|
||
|
|
||
|
newN, err := native.SelectF32(n, 0)
|
||
|
if err != nil {
|
||
|
return []float32{}, err
|
||
|
}
|
||
|
|
||
|
var fullTensor []float32
|
||
|
for _, v := range newN {
|
||
|
fullTensor = append(fullTensor, v...)
|
||
|
}
|
||
|
|
||
|
return fullTensor, nil
|
||
|
}
|
||
|
|
||
|
func (m *GemmaModel) GetTensors() error {
|
||
|
t, err := GetSafeTensors(m.Path, m.Params)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
m.Tensors = []llm.Tensor{}
|
||
|
|
||
|
for _, l := range t {
|
||
|
if strings.HasSuffix(l.Name, "norm.weight") {
|
||
|
wt := l.WriterTo.(safetensorWriterTo)
|
||
|
wt.handler = gemmaLayerHandler
|
||
|
l.WriterTo = wt
|
||
|
}
|
||
|
m.Tensors = append(m.Tensors, l)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *GemmaModel) LoadVocab() error {
|
||
|
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
m.Vocab = v
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *GemmaModel) WriteGGUF() (string, error) {
|
||
|
kv := llm.KV{
|
||
|
"general.architecture": "gemma",
|
||
|
"general.name": m.Name,
|
||
|
"gemma.context_length": uint32(m.Params.ContextSize),
|
||
|
"gemma.embedding_length": uint32(m.Params.HiddenSize),
|
||
|
"gemma.block_count": uint32(m.Params.HiddenLayers),
|
||
|
"gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||
|
"gemma.attention.head_count": uint32(m.Params.AttentionHeads),
|
||
|
"gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||
|
"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||
|
"gemma.attention.key_length": uint32(m.Params.HeadDimension),
|
||
|
"gemma.attention.value_length": uint32(m.Params.HeadDimension),
|
||
|
"general.file_type": uint32(1),
|
||
|
"tokenizer.ggml.model": "llama",
|
||
|
|
||
|
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||
|
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||
|
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||
|
|
||
|
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||
|
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||
|
"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
|
||
|
"tokenizer.ggml.unknown_token_id": uint32(3),
|
||
|
"tokenizer.ggml.add_bos_token": true,
|
||
|
"tokenizer.ggml.add_eos_token": false,
|
||
|
}
|
||
|
|
||
|
f, err := os.CreateTemp("", "ollama-gguf")
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
defer f.Close()
|
||
|
|
||
|
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||
|
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
|
||
|
return f.Name(), nil
|
||
|
}
|