Merge pull request #4917 from ollama/mxyng/convert-bert
convert bert model from safetensors
This commit is contained in:
commit
107f695929
7 changed files with 344 additions and 15 deletions
13
cmd/cmd.go
13
cmd/cmd.go
|
@ -223,6 +223,14 @@ func tempZipFiles(path string) (string, error) {
|
|||
}
|
||||
files = append(files, js...)
|
||||
|
||||
// bert models require a nested config.json
|
||||
// TODO(mxyng): merge this with the glob above
|
||||
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
files = append(files, js...)
|
||||
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
|
@ -252,6 +260,11 @@ func tempZipFiles(path string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
zfi.Name, err = filepath.Rel(path, file)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
zf, err := zipfile.CreateHeader(zfi)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
|
@ -66,6 +66,10 @@ type Converter interface {
|
|||
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
|
||||
}
|
||||
|
||||
type moreParser interface {
|
||||
parseMore(fs.FS) error
|
||||
}
|
||||
|
||||
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
||||
// and files it finds in the input path.
|
||||
// Supported input model formats include safetensors.
|
||||
|
@ -95,6 +99,8 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
|
|||
conv = &gemma{}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3{}
|
||||
case "BertModel":
|
||||
conv = &bert{}
|
||||
default:
|
||||
return errors.New("unsupported architecture")
|
||||
}
|
||||
|
@ -103,6 +109,12 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if t, ok := conv.(moreParser); ok {
|
||||
if err := t.parseMore(fsys); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
176
convert/convert_bert.go
Normal file
176
convert/convert_bert.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"io/fs"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type bert struct {
|
||||
Parameters
|
||||
NLayers uint32 `json:"n_layers"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NLayer uint32 `json:"n_layer"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NCtx uint32 `json:"n_ctx"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NEmbd uint32 `json:"n_embd"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NInner uint32 `json:"n_inner"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NHead uint32 `json:"n_head"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
NormEpsilon float32 `json:"norm_epsilon"`
|
||||
|
||||
PoolingType uint32
|
||||
}
|
||||
|
||||
var (
|
||||
_ Converter = (*bert)(nil)
|
||||
_ moreParser = (*bert)(nil)
|
||||
)
|
||||
|
||||
func (p *bert) parseMore(fsys fs.FS) error {
|
||||
bts, err := fs.ReadFile(fsys, "modules.json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var modules []struct {
|
||||
Type string `json:"type"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, &modules); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pooling string
|
||||
for _, m := range modules {
|
||||
if m.Type == "sentence_transformers.models.Pooling" {
|
||||
pooling = m.Path
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if pooling != "" {
|
||||
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pc struct {
|
||||
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
|
||||
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bts, &pc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pc.PoolingModeMeanTokens {
|
||||
p.PoolingType = 1
|
||||
} else if pc.PoolingModeCLSToken {
|
||||
p.PoolingType = 2
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *bert) KV(t *Tokenizer) llm.KV {
|
||||
kv := p.Parameters.KV(t)
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["general.name"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
kv["bert.pooling_type"] = p.PoolingType
|
||||
|
||||
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
||||
|
||||
if contextLength := cmp.Or(p.MaxPositionEmbeddings, p.NCtx); contextLength > 0 {
|
||||
kv["bert.context_length"] = contextLength
|
||||
}
|
||||
|
||||
if embeddingLength := cmp.Or(p.HiddenSize, p.NEmbd); embeddingLength > 0 {
|
||||
kv["bert.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
|
||||
}
|
||||
|
||||
if feedForwardLength := cmp.Or(p.IntermediateSize, p.NInner); feedForwardLength > 0 {
|
||||
kv["bert.feed_forward_length"] = cmp.Or(p.IntermediateSize, p.NInner)
|
||||
}
|
||||
|
||||
if headCount := cmp.Or(p.NumAttentionHeads, p.NHead); headCount > 0 {
|
||||
kv["bert.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
|
||||
}
|
||||
|
||||
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon, p.NormEpsilon); layerNormEpsilon > 0 {
|
||||
kv["bert.attention.layer_norm_epsilon"] = layerNormEpsilon
|
||||
}
|
||||
|
||||
kv["tokenizer.ggml.model"] = "bert"
|
||||
kv["tokenizer.ggml.token_type_count"] = uint32(2)
|
||||
|
||||
// convert to phantom space tokens
|
||||
for i, e := range t.Tokens {
|
||||
if strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]") {
|
||||
// noop
|
||||
} else if strings.HasPrefix(e, "##") {
|
||||
t.Tokens[i] = e[2:]
|
||||
} else {
|
||||
t.Tokens[i] = "\u2581" + e
|
||||
}
|
||||
}
|
||||
|
||||
kv["tokenizer.ggml.tokens"] = t.Tokens
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *bert) Tensors(ts []Tensor) []llm.Tensor {
|
||||
var out []llm.Tensor
|
||||
for _, t := range ts {
|
||||
if slices.Contains([]string{
|
||||
"embeddings.position_ids",
|
||||
"pooler.dense.weight",
|
||||
"pooler.dense.bias",
|
||||
}, t.Name()) {
|
||||
continue
|
||||
}
|
||||
|
||||
name := p.tensorName(t.Name())
|
||||
out = append(out, llm.Tensor{
|
||||
Name: name,
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (bert) tensorName(n string) string {
|
||||
return strings.NewReplacer(
|
||||
"encoder.layer", "blk",
|
||||
"encoder.layers", "blk",
|
||||
"embeddings.word_embeddings", "token_embd",
|
||||
"embeddings.token_type_embeddings", "token_types",
|
||||
"embeddings.LayerNorm", "token_embd_norm",
|
||||
"embeddings.position_embeddings", "position_embd",
|
||||
"attention.self.query", "attn_q",
|
||||
"attention.self.key", "attn_k",
|
||||
"attention.self.value", "attn_v",
|
||||
"attention.output.dense", "attn_output",
|
||||
"attention.output.LayerNorm", "attn_output_norm",
|
||||
"intermediate.dense", "ffn_up",
|
||||
"output.dense", "ffn_down",
|
||||
"output.LayerNorm", "layer_output_norm",
|
||||
).Replace(n)
|
||||
}
|
|
@ -67,6 +67,7 @@ func TestConvertFull(t *testing.T) {
|
|||
"gemma-2b-it",
|
||||
// microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
|
||||
"Phi-3-mini-128k-instruct",
|
||||
"all-MiniLM-L6-v2",
|
||||
}
|
||||
|
||||
for i := range cases {
|
||||
|
|
|
@ -37,6 +37,8 @@ const (
|
|||
func (t tensorBase) Kind() uint32 {
|
||||
if strings.HasSuffix(t.name, ".block_sparse_moe.gate.weight") {
|
||||
return 0
|
||||
} else if t.name == "embeddings.token_type_embeddings.weight" {
|
||||
return 0
|
||||
}
|
||||
|
||||
switch len(t.shape) {
|
||||
|
|
124
convert/testdata/all-MiniLM-L6-v2.json
vendored
Normal file
124
convert/testdata/all-MiniLM-L6-v2.json
vendored
Normal file
|
@ -0,0 +1,124 @@
|
|||
{
|
||||
"general.architecture": "bert",
|
||||
"general.file_type": "1",
|
||||
"general.quantization_version": "2",
|
||||
"bert.attention.causal": "false",
|
||||
"bert.attention.head_count": "12",
|
||||
"bert.attention.layer_norm_epsilon": "1e-12",
|
||||
"bert.block_count": "6",
|
||||
"bert.context_length": "512",
|
||||
"bert.embedding_length": "384",
|
||||
"bert.feed_forward_length": "1536",
|
||||
"bert.pooling_type": "1",
|
||||
"tokenizer.ggml.model": "bert",
|
||||
"tokenizer.ggml.padding_token_id": "0",
|
||||
"tokenizer.ggml.unknown_token_id": "100",
|
||||
"tokenizer.ggml.cls_token_id": "101",
|
||||
"tokenizer.ggml.seperator_token_id": "102",
|
||||
"tokenizer.ggml.mask_token_id": "103",
|
||||
"tokenizer.ggml.token_type_count": "2",
|
||||
"tokenizer.ggml.scores": "6db964fe67338aca57790481a390121ff3dd643eebe49f7dd308029ad99abb6f",
|
||||
"tokenizer.ggml.token_type": "98d247c5404b6b18f05f133b92dd56edf6efefefac326794b00d7b351f6c5aa1",
|
||||
"tokenizer.ggml.tokens": "9efe405e229a45ff9916f54c475d151d2200cd2ab0006f347abfb069cf096c86",
|
||||
"token_embd.weight": "8c1ee80a9ea4f65aa385ba30112010068af3d209bebc6e149d3d4589c2cd0a5a",
|
||||
"position_embd.weight": "6c516f0b1c4e2388ab90394dd80ad69e4e4509b890982fc3408108ae66210eb6",
|
||||
"token_types.weight": "f879f8e422ed211948f28b560d3c5e17aae7993f063b51196a28cf5c0fb3da21",
|
||||
"token_embd_norm.weight": "75076e095d717aab96f8b6beeee503c27940d9a76f2b891a0e3de72f8a6043e4",
|
||||
"token_embd_norm.bias": "298735285ffe944e1bf03e5d35c7280326b85cf121bde9874f1af5dc51ab939d",
|
||||
"blk.0.attn_q.weight": "ab0923ce4c1549175112dcdfcc860fe30137f991e03ea6857fb5993670adaf6c",
|
||||
"blk.0.attn_q.bias": "a3ec29551dabf976e1d34256b8ab5ab7b758f3ed9742c3cafdbd984d5441df62",
|
||||
"blk.0.attn_k.weight": "4c1038a6d035c3e9ffed7fa672b614627814752503755fbad0cfb76a41ad71ba",
|
||||
"blk.0.attn_k.bias": "e0363930eb588d91816aa3d230bb03b6e2551c165117b80b8d60397413819ef9",
|
||||
"blk.0.attn_v.weight": "425e2e53e3f00ce98d29c3e6a161eb55d3e6ae0d96fdb9f6242d1c4fd6eef4b3",
|
||||
"blk.0.attn_v.bias": "6579173a1e65ee124fbd0bd53cbdca4225515b4f2c5f18fb1bfd000f5978f9bb",
|
||||
"blk.0.attn_output.weight": "a6d70a08cd7164de5d12af65d86d657c3db35aaecde778b2b3fda9193c4c9802",
|
||||
"blk.0.attn_output.bias": "2b8d12c4f9a9c5bfaa29c597839568f6e0525cb41eeaf64ddeb6bd84dfeb9701",
|
||||
"blk.0.attn_output_norm.weight": "bbe6e502a473228b525aeed26cc31b7db123ad63bdc5a6eebac6ea70b8b51d62",
|
||||
"blk.0.attn_output_norm.bias": "36eaacaf0007c5c62daea97aab0115390c0682914f78482e37eb76885f4b7a50",
|
||||
"blk.0.ffn_up.weight": "24654561c76ce387d125759ba843f06b904ef721fcceaeff6ccc62180a48e874",
|
||||
"blk.0.ffn_up.bias": "fd3f0126aa1d95768fa60eb6f4ab8a2763cfcb7e5405f35b92353031d86f4d34",
|
||||
"blk.0.ffn_down.weight": "97a829763a6a5bf3329ceb4d39c424ba4787d61653a5b0bbd1f84782e4d4e0ca",
|
||||
"blk.0.ffn_down.bias": "7aa980c30ae8b4ee7f69df28808dbf5c431f56ccc4a80340f644a0419f16c054",
|
||||
"blk.0.layer_output_norm.weight": "ef30dad4c2a083ae1ff5039a2a6cda60ecc89bf1e486a6f8c0d15f50589603f8",
|
||||
"blk.0.layer_output_norm.bias": "8b1b77e67568b1bce43fc476de1b177c53ff688d66beb66995e8eb3dc290da8a",
|
||||
"blk.1.attn_q.weight": "284331622a1f6f9b87ccee4f652bd66a394ca493c4d93be4d1844e4f6159ad10",
|
||||
"blk.1.attn_q.bias": "e24ebd4860330e08f6bfdd077a82db0bee33f4c8846cf1db26327a34754c7069",
|
||||
"blk.1.attn_k.weight": "729dd0d555544b5bd0f7580b3c8b384256b974605f0e7487b95f295aa032997d",
|
||||
"blk.1.attn_k.bias": "2aa51a828a858f35473f54477583fea54ce2ccc34ea60fbd1d228fbe9bca827f",
|
||||
"blk.1.attn_v.weight": "6be304671cc311d5ca5c103f2b51467ee800c589bc5b8101e09ff5aed1f68c21",
|
||||
"blk.1.attn_v.bias": "43bcbab78a8819e07f723bc9e5b737b71e87a7594f15234e882b63e327a64199",
|
||||
"blk.1.attn_output.weight": "15ec8a1a12b26c9976445308a09f748ab0e4bef0f583d13ab08c3129f8738d73",
|
||||
"blk.1.attn_output.bias": "dac2146f4baa6ed16f6c0dc7443831fb7ec79bedcceafd80d1a4b628a1bb072d",
|
||||
"blk.1.attn_output_norm.weight": "d2151eb33bffac536787a4c9a5d2b31c7a80b17c4611877842a3cce2cd6e98d8",
|
||||
"blk.1.attn_output_norm.bias": "31e1b779716dafb855d2cf5631ee168a0ccf372eb9c6ea6091f66fa97a9b9d2d",
|
||||
"blk.1.ffn_up.weight": "a57547fc3fc3b77406f5cdcb0c87af9bc184701f175c39c1f35297826fce3cc7",
|
||||
"blk.1.ffn_up.bias": "123be6d541d086202913c75d878c54d59a749f3af7b58f7ef9eb9e7c62a24c9a",
|
||||
"blk.1.ffn_down.weight": "cfdb79788377e5cbded8790cd41b9e66c397ecab75474071fcd7cf32d30f9613",
|
||||
"blk.1.ffn_down.bias": "bcb58315519a573097960891c9ae41cf4c685ab78c3e0e77471471758a7eae88",
|
||||
"blk.1.layer_output_norm.weight": "819b554271452bfb1d84c2603b90377b2e41a0ac1e3aa8b417ccf9dce63375bd",
|
||||
"blk.1.layer_output_norm.bias": "47a3433ac27f5ce8947fb38dd491f3706df4ef6adb0ddf74612bf0f54b19e164",
|
||||
"blk.2.attn_q.weight": "1557a9ea852b1880551f7290e00aded4f35e6c4180fdcbed1b0039bf805f639e",
|
||||
"blk.2.attn_q.bias": "c3bfe5f3066f655fd36b055530997b59ff33ef013563aaeb3cb8ff07dabd59a9",
|
||||
"blk.2.attn_k.weight": "cfd08eb69c61ae2f9f14f9b7ff5c5394ca264b1a9f3d48156677f90dd1766289",
|
||||
"blk.2.attn_k.bias": "9b839bc0e79974a0b3f5d1895972bc6f5c9a1bc16052e1af786e6a530758152d",
|
||||
"blk.2.attn_v.weight": "02b26b1208480eaeeb00e7b4cf8b690006ca14759357fc44ed4a2a8924ead993",
|
||||
"blk.2.attn_v.bias": "e7e6f0089fded1659a867ab736c220d9653ea7da6b1b94baf5c8d30a748b63ab",
|
||||
"blk.2.attn_output.weight": "a1db121c7d33806b349cadd050300a57db49fdc91224fd07c9ac43bf4299dc79",
|
||||
"blk.2.attn_output.bias": "7675128b6a92555cd955c820311e91e9417d31f48848f45d047b4100c62148b3",
|
||||
"blk.2.attn_output_norm.weight": "5b4595e0fbcba67a700c4331adf746d2fba3546364a4db5607ae241947bb1a21",
|
||||
"blk.2.attn_output_norm.bias": "7b8e16826ea30e5a2ba0b02e0095a901775981a296e98819625320e983060d08",
|
||||
"blk.2.ffn_up.weight": "a0d815d946ac07a65095c4ae4df77b818845e6d97795c7d82f55e689d944db59",
|
||||
"blk.2.ffn_up.bias": "ce37c0a4174d6bf773ded7bd016ede627ad3bdb8bc99b9992a18dc8e8898f252",
|
||||
"blk.2.ffn_down.weight": "f6231d2a25426fbd45b9f1160aa484220eb227ceef0348c4a6a6de890606e5ef",
|
||||
"blk.2.ffn_down.bias": "429e00556e8dc63a785238b309b9d83738500c1ef6d736fe6526ad88ea496d27",
|
||||
"blk.2.layer_output_norm.weight": "651457a573adf3f7dd9ee5dfe1c8e89389e94443993aab77ec6a0b05aa621e35",
|
||||
"blk.2.layer_output_norm.bias": "41fbbeda7fd89b0cef5f945ae44011c316982390401d6f75ba8c6d365e185247",
|
||||
"blk.3.attn_q.weight": "95a43f32949d2cb8d22815bb27a44abfc6665ba96221af817dfe058cb6ca72c6",
|
||||
"blk.3.attn_q.bias": "f4e34385e75d8108b6b3bd336106e2133a8c9be0cc343dfe5dc48c32a823c7cb",
|
||||
"blk.3.attn_k.weight": "6b892da6a17d4d3265265a15f695864a31813ee8c8e710ae9bc9e1adbc6c9a18",
|
||||
"blk.3.attn_k.bias": "40b8067b641a56014cee42548240aa8930820958b1933004892b5f04fbaef39e",
|
||||
"blk.3.attn_v.weight": "9fcd5922319dd2a461082a5ce040c1dfe65d87d70ca6547dd0b46eeecc3eeb2b",
|
||||
"blk.3.attn_v.bias": "b528c56212e66931fdbe267ac327a9c2f87cd03baff3ea719e30afe681da15f1",
|
||||
"blk.3.attn_output.weight": "e3b178c1b03981e75510e0d277af23ea59cc404b5394e61bd32291825719b502",
|
||||
"blk.3.attn_output.bias": "712c84d39a6a5a9c06a09da8fd9939ba0d5525524a4bba61ea4de09b48f45cae",
|
||||
"blk.3.attn_output_norm.weight": "d1ffac88e675592ff72f8a617be32b4a381d443b2f8f2645dbe44a1e5745aac0",
|
||||
"blk.3.attn_output_norm.bias": "ea31a1c73146234c50e0e43f485c458413714867b8e2703af66482f7db2d6c40",
|
||||
"blk.3.ffn_up.weight": "4ef4f3b9a1ea6ab2ef2eb6e8b008e06a44790d099d97482a05a51e39a29afac0",
|
||||
"blk.3.ffn_up.bias": "06a4296dda16f452675c51f108079fe7722552d6521c737d97734943818b9a2b",
|
||||
"blk.3.ffn_down.weight": "f114b2bebe392c7d80433bb880c6730293aa4561b0b0370dcdaf7472daebd847",
|
||||
"blk.3.ffn_down.bias": "2c8e67831d28a3bf613fc7912ae3259b63d72abcaf4d30efd8800758400158de",
|
||||
"blk.3.layer_output_norm.weight": "a1dfeb7b5a51dd56447312ca41e2ad2f361a3ea12ddc355127f5f4219fb0a482",
|
||||
"blk.3.layer_output_norm.bias": "1ed630021b25c6c6fc93fd32988b9907df966d4982a93081f639aac3044618ab",
|
||||
"blk.4.attn_q.weight": "b5fae4c1f9a5f33a2a2e816ac0c01c25f422e4efdd59ef1ed93da2610e5370fc",
|
||||
"blk.4.attn_q.bias": "c2e376524ea98ac3b10d9eee19ecb1b1e261fa5149efe0232844c923dfb428fb",
|
||||
"blk.4.attn_k.weight": "a4632f5ebf9321d9d08f9112a4e5dda2efe5671df4a4e67fee24845f5b14af16",
|
||||
"blk.4.attn_k.bias": "a9a02ffb8b8b4f6dfe487a7e0341f1d5318c9d2b793a688f34cb1b22fc66ef60",
|
||||
"blk.4.attn_v.weight": "10ad8deb81d9fa093b1e5c0f24ea82aa7df43e6aca49e260fcbea56eab8cc86a",
|
||||
"blk.4.attn_v.bias": "7326813e181e021130bd33ac136293fcffccce2d1d8cb59041e5b13a8cceacf6",
|
||||
"blk.4.attn_output.weight": "c92573088c7437c2b3cda51490e152c27fb19e5468df591eabba5a49d5398d44",
|
||||
"blk.4.attn_output.bias": "14e10b419e5859af1eb685af5c330aee67048cd704dcead9217840c6f5393222",
|
||||
"blk.4.attn_output_norm.weight": "02b6831c0e0fb0edbc579a92812a1dd972cb15d14fcd382d4427c5a7b300ac44",
|
||||
"blk.4.attn_output_norm.bias": "7eed5cd503bb6bb6ceb1bc8b07cc077903a4f14fb8b9d6cdf39644815ecf1374",
|
||||
"blk.4.ffn_up.weight": "8d0c91d62e74d6431321116a37cf3339e630bd50ba164d3304fc4fe8dd831223",
|
||||
"blk.4.ffn_up.bias": "d325f07f73c005a273c484c7be8e7abb4d6e8a5c4fd093f5869133b97629d017",
|
||||
"blk.4.ffn_down.weight": "7ba7bd81143f40537b84f938e403e19f30e4928625eb371de052b9025beb4d21",
|
||||
"blk.4.ffn_down.bias": "2853d9c2a75288214a4bf4907dc19d04d01926f4913d302b1aa7bdbfcce0f7a1",
|
||||
"blk.4.layer_output_norm.weight": "a4ed1885fa77b90fed5300c355ef0aa0c876a8c747151d9d790939d464d57d4f",
|
||||
"blk.4.layer_output_norm.bias": "62142a81e813a9e636333b2b805d6bc3b17c5e7cd4b15adce1ada6bc9a32563c",
|
||||
"blk.5.attn_q.weight": "afc1dff080a72c3daad01384b1448d476aaf789871017c8ff8e144788887995d",
|
||||
"blk.5.attn_q.bias": "748a820371c1d4f872c84545b36358d239c35bf6c99e2812c237d88c3292763b",
|
||||
"blk.5.attn_k.weight": "59e30c1ed8acd2cbb01de5f62e7804015b9ecf98ba157d98cab016344639eda5",
|
||||
"blk.5.attn_k.bias": "f839520078f9e589496e982e86d0126c7aa14196047339abffcf49a696229f77",
|
||||
"blk.5.attn_v.weight": "3e21fb874e21b90308e1f46af034a3c32d3eba1628d62ae5f2246d6af5818923",
|
||||
"blk.5.attn_v.bias": "5cd4852bf95c1444d10d756750f6bf49f842c0b39e9953c7f408bb67c325ac8c",
|
||||
"blk.5.attn_output.weight": "636ce6a7752895f204b9d01ba0aedd9a294f908b42f372c22a16d9dd590d7471",
|
||||
"blk.5.attn_output.bias": "82d924d4b0d2b94f2bbff91619216d6967a3541ce9b1531a6a60457a67b5d219",
|
||||
"blk.5.attn_output_norm.weight": "5e7bd0a8d3396080f3360d7c4700bf094a06216431bd014c4479eef72ecf4271",
|
||||
"blk.5.attn_output_norm.bias": "66c6de5edda5466d029c6753780be81ccd4218bf8bc00680000e0f06856ab712",
|
||||
"blk.5.ffn_up.weight": "5bbf6e7ea380e216e33f8bee06d25f2265359d3876a300e92bc6e41d48e33430",
|
||||
"blk.5.ffn_up.bias": "9d795388bb36fb33ad3a37fea3ccb4937838e02800a608fb47d363cd06b47370",
|
||||
"blk.5.ffn_down.weight": "2fd628974e7f075479dd227b46fbd48ae8d3ca34d735b36f391ac06410730368",
|
||||
"blk.5.ffn_down.bias": "cd213ba9eaa75fa541648097fbe9c96e58077e6c3ad6ad2fb1f21f8350f44291",
|
||||
"blk.5.layer_output_norm.weight": "159a9df41d15b7022d136f86a2a2631c4635f9816e957472217077b522bcf52a",
|
||||
"blk.5.layer_output_norm.bias": "24c1f27ffd1eb4e5be7e3a2909943e6f0980635d761fa1efdd0c19645da23766"
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
@ -11,6 +10,8 @@ import (
|
|||
"log/slog"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -184,32 +185,32 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var tokens []token
|
||||
tokens := make(map[int]token, len(t.Model.Vocab))
|
||||
for k, v := range t.Model.Vocab {
|
||||
tokens = append(tokens, token{
|
||||
tokens[v] = token{
|
||||
ID: v,
|
||||
Content: k,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range t.AddedTokens {
|
||||
t.UserDefined = true
|
||||
tokens = append(tokens, t)
|
||||
for _, token := range t.AddedTokens {
|
||||
token.UserDefined = true
|
||||
tokens[token.ID] = token
|
||||
}
|
||||
|
||||
slices.SortFunc(tokens, func(i, j token) int {
|
||||
return cmp.Compare(i.ID, j.ID)
|
||||
})
|
||||
keys := maps.Keys(tokens)
|
||||
slices.Sort(keys)
|
||||
|
||||
v := Vocabulary{Model: "gpt2"}
|
||||
for _, t := range tokens {
|
||||
v.Tokens = append(v.Tokens, t.Content)
|
||||
v.Scores = append(v.Scores, float32(t.ID))
|
||||
for _, k := range keys {
|
||||
token := tokens[k]
|
||||
v.Tokens = append(v.Tokens, token.Content)
|
||||
v.Scores = append(v.Scores, float32(token.ID))
|
||||
|
||||
switch {
|
||||
case t.Special:
|
||||
case token.Special:
|
||||
v.Types = append(v.Types, tokenTypeControl)
|
||||
case t.UserDefined:
|
||||
case token.UserDefined:
|
||||
v.Types = append(v.Types, tokenTypeUserDefined)
|
||||
default:
|
||||
v.Types = append(v.Types, tokenTypeNormal)
|
||||
|
|
Loading…
Reference in a new issue