174 lines
4.3 KiB
Go
174 lines
4.3 KiB
Go
package convert
|
|
|
|
import (
|
|
"cmp"
|
|
"encoding/json"
|
|
"io/fs"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/llm"
|
|
)
|
|
|
|
type bertModel struct {
|
|
ModelParameters
|
|
NLayers uint32 `json:"n_layers"`
|
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
|
NLayer uint32 `json:"n_layer"`
|
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
|
NCtx uint32 `json:"n_ctx"`
|
|
HiddenSize uint32 `json:"hidden_size"`
|
|
NEmbd uint32 `json:"n_embd"`
|
|
IntermediateSize uint32 `json:"intermediate_size"`
|
|
NInner uint32 `json:"n_inner"`
|
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
|
NHead uint32 `json:"n_head"`
|
|
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
|
LayerNormEPS float32 `json:"layer_norm_eps"`
|
|
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
|
NormEpsilon float32 `json:"norm_epsilon"`
|
|
|
|
PoolingType uint32
|
|
}
|
|
|
|
var (
|
|
_ ModelConverter = (*bertModel)(nil)
|
|
_ moreParser = (*bertModel)(nil)
|
|
)
|
|
|
|
func (p *bertModel) parseMore(fsys fs.FS) error {
|
|
bts, err := fs.ReadFile(fsys, "modules.json")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var modules []struct {
|
|
Type string `json:"type"`
|
|
Path string `json:"path"`
|
|
}
|
|
|
|
if err := json.Unmarshal(bts, &modules); err != nil {
|
|
return err
|
|
}
|
|
|
|
var pooling string
|
|
for _, m := range modules {
|
|
if m.Type == "sentence_transformers.models.Pooling" {
|
|
pooling = m.Path
|
|
break
|
|
}
|
|
}
|
|
|
|
if pooling != "" {
|
|
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var pc struct {
|
|
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
|
|
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
|
|
}
|
|
|
|
if err := json.Unmarshal(bts, &pc); err != nil {
|
|
return err
|
|
}
|
|
|
|
if pc.PoolingModeMeanTokens {
|
|
p.PoolingType = 1
|
|
} else if pc.PoolingModeCLSToken {
|
|
p.PoolingType = 2
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *bertModel) KV(t *Tokenizer) llm.KV {
|
|
kv := p.ModelParameters.KV(t)
|
|
kv["general.architecture"] = "bert"
|
|
kv["bert.attention.causal"] = false
|
|
kv["bert.pooling_type"] = p.PoolingType
|
|
|
|
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
|
|
|
if contextLength := cmp.Or(p.MaxPositionEmbeddings, p.NCtx); contextLength > 0 {
|
|
kv["bert.context_length"] = contextLength
|
|
}
|
|
|
|
if embeddingLength := cmp.Or(p.HiddenSize, p.NEmbd); embeddingLength > 0 {
|
|
kv["bert.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
|
|
}
|
|
|
|
if feedForwardLength := cmp.Or(p.IntermediateSize, p.NInner); feedForwardLength > 0 {
|
|
kv["bert.feed_forward_length"] = cmp.Or(p.IntermediateSize, p.NInner)
|
|
}
|
|
|
|
if headCount := cmp.Or(p.NumAttentionHeads, p.NHead); headCount > 0 {
|
|
kv["bert.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
|
|
}
|
|
|
|
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon, p.NormEpsilon); layerNormEpsilon > 0 {
|
|
kv["bert.attention.layer_norm_epsilon"] = layerNormEpsilon
|
|
}
|
|
|
|
kv["tokenizer.ggml.model"] = "bert"
|
|
kv["tokenizer.ggml.token_type_count"] = uint32(2)
|
|
|
|
// convert to phantom space tokens
|
|
for i, e := range t.Tokens {
|
|
if strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]") {
|
|
// noop
|
|
} else if strings.HasPrefix(e, "##") {
|
|
t.Tokens[i] = e[2:]
|
|
} else {
|
|
t.Tokens[i] = "\u2581" + e
|
|
}
|
|
}
|
|
|
|
kv["tokenizer.ggml.tokens"] = t.Tokens
|
|
|
|
return kv
|
|
}
|
|
|
|
func (p *bertModel) Tensors(ts []Tensor) []llm.Tensor {
|
|
var out []llm.Tensor
|
|
for _, t := range ts {
|
|
if slices.Contains([]string{
|
|
"embeddings.position_ids",
|
|
"pooler.dense.weight",
|
|
"pooler.dense.bias",
|
|
}, t.Name()) {
|
|
continue
|
|
}
|
|
|
|
out = append(out, llm.Tensor{
|
|
Name: t.Name(),
|
|
Kind: t.Kind(),
|
|
Shape: t.Shape(),
|
|
WriterTo: t,
|
|
})
|
|
}
|
|
|
|
return out
|
|
}
|
|
|
|
func (bertModel) Replacements() []string {
|
|
return []string{
|
|
"encoder.layer", "blk",
|
|
"encoder.layers", "blk",
|
|
"embeddings.word_embeddings", "token_embd",
|
|
"embeddings.token_type_embeddings", "token_types",
|
|
"embeddings.LayerNorm", "token_embd_norm",
|
|
"embeddings.position_embeddings", "position_embd",
|
|
"attention.self.query", "attn_q",
|
|
"attention.self.key", "attn_k",
|
|
"attention.self.value", "attn_v",
|
|
"attention.output.dense", "attn_output",
|
|
"attention.output.LayerNorm", "attn_output_norm",
|
|
"intermediate.dense", "ffn_up",
|
|
"output.dense", "ffn_down",
|
|
"output.LayerNorm", "layer_output_norm",
|
|
}
|
|
}
|