add safetensors version

This commit is contained in:
Patrick Devine 2024-04-24 18:32:01 -07:00 committed by Michael Yang
parent d88582dffd
commit 4730762e5c
2 changed files with 20 additions and 4 deletions

View file

@ -20,7 +20,7 @@ type LlamaModel struct {
ModelData
}
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
data := r.storage.(*pytorch.HalfStorage).Data
@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaLayerHandler
l.WriterTo = wt
switch l.WriterTo.(type) {
case torchWriterTo:
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaTorchLayerHandler
l.WriterTo = wt
case safetensorWriterTo:
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
}
m.Tensors = append(m.Tensors, l)
}

View file

@ -281,6 +281,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
return nil, fmt.Errorf("No architecture specified to convert")
case 1:
switch params.Architectures[0] {
case "LlamaForCausalLM":
return &LlamaModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
case "MistralForCausalLM":
return &MistralModel{
ModelData{