ollama/convert/tokenizer_spm.go

84 lines
1.7 KiB
Go
Raw Normal View History

2024-05-31 20:00:49 -07:00
package convert
import (
"cmp"
"encoding/json"
"errors"
"fmt"
2024-06-29 16:53:59 -07:00
"io/fs"
2024-05-31 20:00:49 -07:00
"os"
"slices"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
2024-06-29 16:53:59 -07:00
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
bts, err := fs.ReadFile(fsys, "tokenizer.model")
2024-05-31 20:00:49 -07:00
if err != nil {
return nil, err
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
return nil, err
}
v := Vocabulary{Model: "llama"}
for _, piece := range spm.GetPieces() {
v.Tokens = append(v.Tokens, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, int32(t))
default:
v.Types = append(v.Types, int32(sentencepiece.ModelProto_SentencePiece_NORMAL))
}
}
2024-06-29 16:53:59 -07:00
f, err := fsys.Open("added_tokens.json")
2024-05-31 20:00:49 -07:00
if errors.Is(err, os.ErrNotExist) {
return &v, nil
} else if err != nil {
return nil, err
}
defer f.Close()
var atm map[string]int
if err := json.NewDecoder(f).Decode(&atm); err != nil {
return nil, err
}
type t struct {
id int
content string
}
var ts []t
for content, id := range atm {
ts = append(ts, t{id, content})
}
slices.SortFunc(ts, func(i, j t) int {
return cmp.Compare(i.id, j.id)
})
n := len(v.Tokens)
for i, t := range ts {
if t.id != i+n {
return nil, fmt.Errorf("invalid token id: %d", t.id)
}
v.Tokens = append(v.Tokens, t.content)
v.Scores = append(v.Scores, -1000.0)
v.Types = append(v.Types, tokenTypeUserDefined)
}
return &v, nil
}