2024-03-07 05:01:51 +00:00
|
|
|
package convert
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/json"
|
2024-06-01 03:00:49 +00:00
|
|
|
"errors"
|
2024-03-07 05:01:51 +00:00
|
|
|
"fmt"
|
2024-04-12 20:55:12 +00:00
|
|
|
"io"
|
2024-06-29 23:53:59 +00:00
|
|
|
"io/fs"
|
2024-03-07 05:01:51 +00:00
|
|
|
"log/slog"
|
2024-06-28 20:27:05 +00:00
|
|
|
"strings"
|
2024-03-07 05:01:51 +00:00
|
|
|
|
2024-03-26 20:04:17 +00:00
|
|
|
"github.com/ollama/ollama/llm"
|
2024-03-07 05:01:51 +00:00
|
|
|
)
|
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
// Parameters holds the config.json fields shared by every architecture.
// Convert decodes config.json into a Parameters first, then decodes the same
// bytes again into the architecture-specific converter.
type Parameters struct {
	// Architectures lists the model classes named in config.json; Convert
	// selects a converter based on the first entry.
	Architectures []string `json:"architectures"`

	// VocabSize is the vocabulary size declared in config.json. Convert pads
	// the tokenizer vocabulary with placeholder tokens up to this size when
	// the tokenizer provides fewer tokens.
	VocabSize uint32 `json:"vocab_size"`
}
|
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
func (Parameters) KV(t *Tokenizer) llm.KV {
|
|
|
|
kv := llm.KV{
|
|
|
|
"general.file_type": uint32(1),
|
|
|
|
"general.quantization_version": uint32(2),
|
|
|
|
"tokenizer.ggml.pre": t.Pre,
|
|
|
|
"tokenizer.ggml.model": t.Vocabulary.Model,
|
|
|
|
"tokenizer.ggml.tokens": t.Vocabulary.Tokens,
|
|
|
|
"tokenizer.ggml.scores": t.Vocabulary.Scores,
|
|
|
|
"tokenizer.ggml.token_type": t.Vocabulary.Types,
|
|
|
|
}
|
2024-03-07 05:01:51 +00:00
|
|
|
|
2024-06-03 22:53:58 +00:00
|
|
|
if len(t.Merges) > 0 {
|
|
|
|
kv["tokenizer.ggml.merges"] = t.Merges
|
|
|
|
}
|
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
if t.Template != "" {
|
|
|
|
kv["tokenizer.chat_template"] = t.Template
|
|
|
|
}
|
2024-04-01 23:14:53 +00:00
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
for _, sv := range t.SpecialVocabulary {
|
|
|
|
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
|
|
|
|
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
|
|
|
|
}
|
2024-04-15 18:26:42 +00:00
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
return kv
|
2024-04-01 23:14:53 +00:00
|
|
|
}
|
|
|
|
|
2024-07-08 23:59:48 +00:00
|
|
|
func (Parameters) specialTokenTypes() []string {
|
2024-06-01 03:00:49 +00:00
|
|
|
return []string{
|
|
|
|
"bos", "eos", "unk", "sep", "pad", "cls", "mask",
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
2024-06-01 03:00:49 +00:00
|
|
|
}
|
2024-03-07 05:01:51 +00:00
|
|
|
|
2024-07-08 23:59:48 +00:00
|
|
|
func (Parameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
|
2024-06-01 03:00:49 +00:00
|
|
|
return llm.WriteGGUF(ws, kv, ts)
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
type Converter interface {
|
|
|
|
// KV maps parameters to LLM key-values
|
|
|
|
KV(*Tokenizer) llm.KV
|
|
|
|
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
|
2024-07-08 23:59:48 +00:00
|
|
|
Tensors([]Tensor) []llm.Tensor
|
2024-06-28 20:27:05 +00:00
|
|
|
// Replacements returns a list of string pairs to replace in tensor names.
|
|
|
|
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
|
|
|
|
Replacements() []string
|
2024-06-01 03:00:49 +00:00
|
|
|
|
2024-07-08 23:59:48 +00:00
|
|
|
// specialTokenTypes returns any special token types the model uses
|
|
|
|
specialTokenTypes() []string
|
2024-06-28 20:27:05 +00:00
|
|
|
// writeFile writes the model to the provided io.WriteSeeker
|
2024-07-08 23:59:48 +00:00
|
|
|
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-06-06 15:59:04 +00:00
|
|
|
// moreParser is an optional interface for converters that need more than
// config.json. After decoding config.json into the converter, Convert checks
// for this interface and, if present, calls parseMore with the input file
// system. A non-nil error aborts the conversion.
type moreParser interface {
	parseMore(fsys fs.FS) error
}
|
|
|
|
|
2024-07-08 23:59:48 +00:00
|
|
|
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
|
|
|
|
// and files it finds in the input path.
|
|
|
|
// Supported input model formats include safetensors.
|
|
|
|
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
|
2024-06-29 23:53:59 +00:00
|
|
|
func Convert(fsys fs.FS, ws io.WriteSeeker) error {
|
|
|
|
bts, err := fs.ReadFile(fsys, "config.json")
|
2024-03-07 05:01:51 +00:00
|
|
|
if err != nil {
|
2024-06-01 03:00:49 +00:00
|
|
|
return err
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
var p Parameters
|
2024-07-08 23:59:48 +00:00
|
|
|
if err := json.Unmarshal(bts, &p); err != nil {
|
2024-06-01 03:00:49 +00:00
|
|
|
return err
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
if len(p.Architectures) < 1 {
|
|
|
|
return errors.New("unknown architecture")
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-07-08 23:59:48 +00:00
|
|
|
var conv Converter
|
2024-06-01 03:00:49 +00:00
|
|
|
switch p.Architectures[0] {
|
|
|
|
case "LlamaForCausalLM", "MistralForCausalLM":
|
2024-07-08 23:59:48 +00:00
|
|
|
conv = &llama{}
|
2024-06-01 03:00:49 +00:00
|
|
|
case "MixtralForCausalLM":
|
2024-07-08 23:59:48 +00:00
|
|
|
conv = &mixtral{}
|
2024-06-01 03:00:49 +00:00
|
|
|
case "GemmaForCausalLM":
|
2024-07-08 23:59:48 +00:00
|
|
|
conv = &gemma{}
|
2024-06-28 20:27:05 +00:00
|
|
|
case "Gemma2ForCausalLM":
|
|
|
|
conv = &gemma2{}
|
2024-06-03 22:53:58 +00:00
|
|
|
case "Phi3ForCausalLM":
|
|
|
|
conv = &phi3{}
|
2024-06-06 15:59:04 +00:00
|
|
|
case "BertModel":
|
|
|
|
conv = &bert{}
|
2024-06-01 03:00:49 +00:00
|
|
|
default:
|
|
|
|
return errors.New("unsupported architecture")
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-07-08 23:59:48 +00:00
|
|
|
if err := json.Unmarshal(bts, conv); err != nil {
|
2024-06-01 03:00:49 +00:00
|
|
|
return err
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-06-06 15:59:04 +00:00
|
|
|
if t, ok := conv.(moreParser); ok {
|
|
|
|
if err := t.parseMore(fsys); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-06-29 23:53:59 +00:00
|
|
|
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
|
2024-06-01 03:00:49 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
|
|
|
|
2024-06-01 03:00:49 +00:00
|
|
|
if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
|
|
|
|
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
|
|
|
|
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
|
|
|
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
|
|
|
|
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
|
|
|
|
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
2024-07-08 23:59:48 +00:00
|
|
|
} else {
|
|
|
|
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|
2024-06-01 03:00:49 +00:00
|
|
|
|
2024-06-28 20:27:05 +00:00
|
|
|
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
|
2024-06-01 03:00:49 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
2024-03-29 01:54:01 +00:00
|
|
|
}
|
|
|
|
|
2024-07-08 23:59:48 +00:00
|
|
|
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
|
2024-03-07 05:01:51 +00:00
|
|
|
}
|