ollama/convert/convert.go

229 lines
5.8 KiB
Go
Raw Normal View History

package convert
import (
"encoding/json"
2024-05-31 20:00:49 -07:00
"errors"
"fmt"
"io"
2024-06-29 16:53:59 -07:00
"io/fs"
"log/slog"
2024-06-28 13:27:05 -07:00
"strings"
"github.com/ollama/ollama/llm"
)
// ModelParameters holds the generic fields ConvertModel decodes from a
// model's config.json before handing the same bytes to an
// architecture-specific converter.
type ModelParameters struct {
	// Architectures lists the model class names from config.json; ConvertModel
	// selects a converter from the first entry.
	Architectures []string `json:"architectures"`
	// VocabSize is the expected vocabulary size. When the parsed tokenizer has
	// fewer tokens, ConvertModel pads the vocabulary up to this size.
	VocabSize uint32 `json:"vocab_size"`
}
// AdapterParameters holds the fields decoded from an adapter's
// adapter_config.json. An alpha value may appear at the top level
// ("lora_alpha") and/or nested under "lora_parameters".
type AdapterParameters struct {
	// Alpha is the top-level "lora_alpha" value.
	Alpha uint32 `json:"lora_alpha"`
	// LoraLayers is the "lora_layers" value.
	LoraLayers uint32 `json:"lora_layers"`
	// LoraParameters mirrors the nested "lora_parameters" object.
	LoraParameters struct {
		Rank  uint32  `json:"rank"`
		Alpha float32 `json:"alpha"`
		Scale float32 `json:"scale"`
	} `json:"lora_parameters"`
}
func (ModelParameters) KV(t *Tokenizer) llm.KV {
2024-05-31 20:00:49 -07:00
kv := llm.KV{
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
"tokenizer.ggml.model": t.Vocabulary.Model,
"tokenizer.ggml.tokens": t.Vocabulary.Tokens,
"tokenizer.ggml.scores": t.Vocabulary.Scores,
"tokenizer.ggml.token_type": t.Vocabulary.Types,
}
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
2024-05-31 20:00:49 -07:00
if t.Template != "" {
kv["tokenizer.chat_template"] = t.Template
}
2024-04-01 16:14:53 -07:00
2024-05-31 20:00:49 -07:00
for _, sv := range t.SpecialVocabulary {
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
}
2024-05-31 20:00:49 -07:00
return kv
2024-04-01 16:14:53 -07:00
}
// KV returns the GGUF key-values describing a LoRA adapter.
//
// The value stored under "adapter.lora.alpha" is lora_parameters.alpha when
// that is non-zero; otherwise it falls back to the top-level lora_alpha.
func (p AdapterParameters) KV() llm.KV {
	alpha := p.LoraParameters.Alpha
	if alpha == 0 {
		alpha = float32(p.Alpha)
	}

	return llm.KV{
		"adapter.lora.alpha": alpha,
		"adapter.type":       "lora",
		"general.file_type":  uint32(1),
		"general.type":       "adapter",
		"general.version":    "v0.2",
	}
}
func (ModelParameters) specialTokenTypes() []string {
2024-05-31 20:00:49 -07:00
return []string{
"bos", "eos", "unk", "sep", "pad", "cls", "mask",
}
2024-05-31 20:00:49 -07:00
}
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
2024-05-31 20:00:49 -07:00
return llm.WriteGGUF(ws, kv, ts)
}
// writeFile serializes the adapter key-values and tensors to ws via
// llm.WriteGGUF, returning any error it reports.
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
	if err := llm.WriteGGUF(ws, kv, ts); err != nil {
		return err
	}
	return nil
}
type ModelConverter interface {
2024-05-31 20:00:49 -07:00
// KV maps parameters to LLM key-values
KV(*Tokenizer) llm.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
2024-07-08 16:59:48 -07:00
Tensors([]Tensor) []llm.Tensor
2024-06-28 13:27:05 -07:00
// Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
2024-05-31 20:00:49 -07:00
2024-07-08 16:59:48 -07:00
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
2024-06-28 13:27:05 -07:00
// writeFile writes the model to the provided io.WriteSeeker
2024-07-08 16:59:48 -07:00
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
}
// moreParser is an optional interface for model converters that need extra
// information from the input file system. ConvertModel calls parseMore after
// decoding config.json into the converter and before parsing the tokenizer.
type moreParser interface {
	parseMore(fs.FS) error
}
// AdapterConverter is the per-architecture hook ConvertAdapter uses to turn
// an adapter's configuration and tensors into GGUF output.
type AdapterConverter interface {
	// KV returns the adapter's key-values, given the key-values of the base
	// model the adapter is converted against.
	KV(baseKV llm.KV) llm.KV
	// Tensors maps input tensors to LLM tensors. Adapter-specific
	// modifications can be done here.
	Tensors(ts []Tensor) []llm.Tensor
	// Replacements returns old/new string pairs to replace in tensor names,
	// in the form accepted by [strings.NewReplacer].
	Replacements() []string
	// writeFile writes the converted adapter to the provided io.WriteSeeker.
	writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error
}
// ConvertAdapter writes an Ollama compatible adapter to ws based on the
// adapter_config.json and tensor files it finds in fsys.
//
// baseKV holds the key-values of the base model. Its "general.architecture"
// entry selects the adapter converter, and it is also passed to the
// converter's KV method. An error is returned if the architecture is missing
// or unsupported, or if reading, decoding, parsing or writing fails.
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
	bts, err := fs.ReadFile(fsys, "adapter_config.json")
	if err != nil {
		return err
	}

	arch, ok := baseKV["general.architecture"]
	if !ok {
		return errors.New("architecture not set for the base model")
	}

	var conv AdapterConverter
	switch arch {
	case "llama":
		conv = &llamaAdapter{}
	case "gemma2":
		conv = &gemma2Adapter{}
	default:
		return errors.New("unsupported architecture")
	}

	// Decode the config straight into the converter (the only consumer of
	// these values) before touching tensors, as ConvertModel does. Malformed
	// configs then fail before the tensor parse, and Replacements() runs
	// against a fully populated converter.
	if err := json.Unmarshal(bts, conv); err != nil {
		return err
	}

	ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
	if err != nil {
		return err
	}

	return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
}
2024-07-08 16:59:48 -07:00
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
2024-06-29 16:53:59 -07:00
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
2024-05-31 20:00:49 -07:00
return err
}
var p ModelParameters
2024-07-08 16:59:48 -07:00
if err := json.Unmarshal(bts, &p); err != nil {
2024-05-31 20:00:49 -07:00
return err
}
2024-05-31 20:00:49 -07:00
if len(p.Architectures) < 1 {
return errors.New("unknown architecture")
}
var conv ModelConverter
2024-05-31 20:00:49 -07:00
switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llamaModel{}
2024-05-31 20:00:49 -07:00
case "MixtralForCausalLM":
conv = &mixtralModel{}
2024-05-31 20:00:49 -07:00
case "GemmaForCausalLM":
conv = &gemmaModel{}
2024-06-28 13:27:05 -07:00
case "Gemma2ForCausalLM":
conv = &gemma2Model{}
case "Phi3ForCausalLM":
conv = &phi3Model{}
2024-06-06 08:59:04 -07:00
case "BertModel":
conv = &bertModel{}
2024-05-31 20:00:49 -07:00
default:
return errors.New("unsupported architecture")
}
2024-07-08 16:59:48 -07:00
if err := json.Unmarshal(bts, conv); err != nil {
2024-05-31 20:00:49 -07:00
return err
}
2024-06-06 08:59:04 -07:00
if t, ok := conv.(moreParser); ok {
if err := t.parseMore(fsys); err != nil {
return err
}
}
2024-06-29 16:53:59 -07:00
t, err := parseTokenizer(fsys, conv.specialTokenTypes())
2024-05-31 20:00:49 -07:00
if err != nil {
return err
}
2024-05-31 20:00:49 -07:00
if vocabSize := int(p.VocabSize); vocabSize > len(t.Vocabulary.Tokens) {
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", p.VocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
}
2024-07-08 16:59:48 -07:00
} else {
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
}
2024-05-31 20:00:49 -07:00
2024-06-28 13:27:05 -07:00
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
2024-05-31 20:00:49 -07:00
if err != nil {
return err
}
2024-07-08 16:59:48 -07:00
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
}