add safetensors version
This commit is contained in:
parent
d88582dffd
commit
4730762e5c
2 changed files with 20 additions and 4 deletions
|
@ -20,7 +20,7 @@ type LlamaModel struct {
|
||||||
ModelData
|
ModelData
|
||||||
}
|
}
|
||||||
|
|
||||||
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
|
func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
|
||||||
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
|
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
|
||||||
|
|
||||||
data := r.storage.(*pytorch.HalfStorage).Data
|
data := r.storage.(*pytorch.HalfStorage).Data
|
||||||
|
@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
|
||||||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
|
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
|
||||||
wt := l.WriterTo.(torchWriterTo)
|
switch l.WriterTo.(type) {
|
||||||
wt.handler = llamaLayerHandler
|
case torchWriterTo:
|
||||||
l.WriterTo = wt
|
wt := l.WriterTo.(torchWriterTo)
|
||||||
|
wt.handler = llamaTorchLayerHandler
|
||||||
|
l.WriterTo = wt
|
||||||
|
case safetensorWriterTo:
|
||||||
|
wt := l.WriterTo.(safetensorWriterTo)
|
||||||
|
wt.handler = mistralLayerHandler
|
||||||
|
l.WriterTo = wt
|
||||||
|
}
|
||||||
}
|
}
|
||||||
m.Tensors = append(m.Tensors, l)
|
m.Tensors = append(m.Tensors, l)
|
||||||
}
|
}
|
||||||
|
|
|
@ -281,6 +281,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
|
||||||
return nil, fmt.Errorf("No architecture specified to convert")
|
return nil, fmt.Errorf("No architecture specified to convert")
|
||||||
case 1:
|
case 1:
|
||||||
switch params.Architectures[0] {
|
switch params.Architectures[0] {
|
||||||
|
case "LlamaForCausalLM":
|
||||||
|
return &LlamaModel{
|
||||||
|
ModelData{
|
||||||
|
Name: name,
|
||||||
|
Path: dirPath,
|
||||||
|
Params: params,
|
||||||
|
Format: m,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
case "MistralForCausalLM":
|
case "MistralForCausalLM":
|
||||||
return &MistralModel{
|
return &MistralModel{
|
||||||
ModelData{
|
ModelData{
|
||||||
|
|
Loading…
Reference in a new issue