add safetensors version
This commit is contained in:
parent
d88582dffd
commit
4730762e5c
2 changed files with 20 additions and 4 deletions
|
@ -20,7 +20,7 @@ type LlamaModel struct {
|
|||
ModelData
|
||||
}
|
||||
|
||||
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
|
||||
func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
|
||||
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
|
||||
|
||||
data := r.storage.(*pytorch.HalfStorage).Data
|
||||
|
@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
|
|||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||
if len(matches) > 0 {
|
||||
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
|
||||
wt := l.WriterTo.(torchWriterTo)
|
||||
wt.handler = llamaLayerHandler
|
||||
l.WriterTo = wt
|
||||
switch l.WriterTo.(type) {
|
||||
case torchWriterTo:
|
||||
wt := l.WriterTo.(torchWriterTo)
|
||||
wt.handler = llamaTorchLayerHandler
|
||||
l.WriterTo = wt
|
||||
case safetensorWriterTo:
|
||||
wt := l.WriterTo.(safetensorWriterTo)
|
||||
wt.handler = mistralLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
|
|
@ -281,6 +281,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
|
|||
return nil, fmt.Errorf("No architecture specified to convert")
|
||||
case 1:
|
||||
switch params.Architectures[0] {
|
||||
case "LlamaForCausalLM":
|
||||
return &LlamaModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
Format: m,
|
||||
},
|
||||
}, nil
|
||||
case "MistralForCausalLM":
|
||||
return &MistralModel{
|
||||
ModelData{
|
||||
|
|
Loading…
Add table
Reference in a new issue