2024-05-31 20:00:49 -07:00
|
|
|
package convert
|
|
|
|
|
|
|
|
import (
|
|
|
|
"cmp"
|
|
|
|
"encoding/json"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
2024-06-29 16:53:59 -07:00
|
|
|
"io/fs"
|
2024-05-31 20:00:49 -07:00
|
|
|
"os"
|
|
|
|
"slices"
|
|
|
|
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
|
|
|
|
"github.com/ollama/ollama/convert/sentencepiece"
|
|
|
|
)
|
|
|
|
|
2024-06-29 16:53:59 -07:00
|
|
|
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
|
|
|
bts, err := fs.ReadFile(fsys, "tokenizer.model")
|
2024-05-31 20:00:49 -07:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
var spm sentencepiece.ModelProto
|
|
|
|
if err := proto.Unmarshal(bts, &spm); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
v := Vocabulary{Model: "llama"}
|
|
|
|
for _, piece := range spm.GetPieces() {
|
|
|
|
v.Tokens = append(v.Tokens, piece.GetPiece())
|
|
|
|
v.Scores = append(v.Scores, piece.GetScore())
|
|
|
|
|
|
|
|
switch t := piece.GetType(); t {
|
|
|
|
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
|
|
|
|
sentencepiece.ModelProto_SentencePiece_CONTROL,
|
|
|
|
sentencepiece.ModelProto_SentencePiece_UNUSED,
|
|
|
|
sentencepiece.ModelProto_SentencePiece_BYTE:
|
|
|
|
v.Types = append(v.Types, int32(t))
|
|
|
|
default:
|
|
|
|
v.Types = append(v.Types, int32(sentencepiece.ModelProto_SentencePiece_NORMAL))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-06-29 16:53:59 -07:00
|
|
|
f, err := fsys.Open("added_tokens.json")
|
2024-05-31 20:00:49 -07:00
|
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
|
|
return &v, nil
|
|
|
|
} else if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
defer f.Close()
|
|
|
|
|
|
|
|
var atm map[string]int
|
|
|
|
if err := json.NewDecoder(f).Decode(&atm); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
type t struct {
|
|
|
|
id int
|
|
|
|
content string
|
|
|
|
}
|
|
|
|
|
|
|
|
var ts []t
|
|
|
|
for content, id := range atm {
|
|
|
|
ts = append(ts, t{id, content})
|
|
|
|
}
|
|
|
|
|
|
|
|
slices.SortFunc(ts, func(i, j t) int {
|
|
|
|
return cmp.Compare(i.id, j.id)
|
|
|
|
})
|
|
|
|
|
|
|
|
n := len(v.Tokens)
|
|
|
|
for i, t := range ts {
|
|
|
|
if t.id != i+n {
|
|
|
|
return nil, fmt.Errorf("invalid token id: %d", t.id)
|
|
|
|
}
|
|
|
|
|
|
|
|
v.Tokens = append(v.Tokens, t.content)
|
|
|
|
v.Scores = append(v.Scores, -1000.0)
|
|
|
|
v.Types = append(v.Types, tokenTypeUserDefined)
|
|
|
|
}
|
|
|
|
|
|
|
|
return &v, nil
|
|
|
|
}
|