From 4730762e5c9453f304aa456b549530e165ff1936 Mon Sep 17 00:00:00 2001
From: Patrick Devine
Date: Wed, 24 Apr 2024 18:32:01 -0700
Subject: [PATCH] add safetensors version

---
 convert/llama.go       | 15 +++++++++++----
 convert/safetensors.go |  9 +++++++++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/convert/llama.go b/convert/llama.go
index fb576e2e..5dfb8d7d 100644
--- a/convert/llama.go
+++ b/convert/llama.go
@@ -20,7 +20,7 @@ type LlamaModel struct {
 	ModelData
 }
 
-func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
+func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
 	slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
 
 	data := r.storage.(*pytorch.HalfStorage).Data
@@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
 		matches := re.FindAllStringSubmatch(l.Name, -1)
 		if len(matches) > 0 {
 			slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
-			wt := l.WriterTo.(torchWriterTo)
-			wt.handler = llamaLayerHandler
-			l.WriterTo = wt
+			switch l.WriterTo.(type) {
+			case torchWriterTo:
+				wt := l.WriterTo.(torchWriterTo)
+				wt.handler = llamaTorchLayerHandler
+				l.WriterTo = wt
+			case safetensorWriterTo:
+				wt := l.WriterTo.(safetensorWriterTo)
+				wt.handler = mistralLayerHandler
+				l.WriterTo = wt
+			}
 		}
 		m.Tensors = append(m.Tensors, l)
 	}
diff --git a/convert/safetensors.go b/convert/safetensors.go
index 69424c4d..64aaf866 100644
--- a/convert/safetensors.go
+++ b/convert/safetensors.go
@@ -281,6 +281,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
 		return nil, fmt.Errorf("No architecture specified to convert")
 	case 1:
 		switch params.Architectures[0] {
+		case "LlamaForCausalLM":
+			return &LlamaModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+					Format: m,
+				},
+			}, nil
 		case "MistralForCausalLM":
 			return &MistralModel{
 				ModelData{
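
The core of the llama.go change is dispatching on the concrete type behind the tensor's
io.WriterTo so that torch-backed and safetensors-backed layers each get the right repack
handler. Below is a minimal, self-contained sketch of that dispatch shape only; the
torchWriterTo/safetensorWriterTo structs, field names, and handler bodies here are
simplified stand-ins for illustration, not the real types from convert/llama.go and
convert/safetensors.go (the real handlers do the actual weight repacking).

// dispatch_sketch.go - illustrative only, assumes simplified stand-in types.
package main

import (
	"fmt"
	"io"
	"os"
)

// torchWriterTo stands in for a tensor backed by a PyTorch checkpoint.
type torchWriterTo struct {
	name    string
	handler func(w io.Writer) error
}

func (t torchWriterTo) WriteTo(w io.Writer) (int64, error) {
	if t.handler != nil {
		return 0, t.handler(w)
	}
	return 0, nil
}

// safetensorWriterTo stands in for a tensor backed by a safetensors file.
type safetensorWriterTo struct {
	name    string
	handler func(w io.Writer) error
}

func (s safetensorWriterTo) WriteTo(w io.Writer) (int64, error) {
	if s.handler != nil {
		return 0, s.handler(w)
	}
	return 0, nil
}

func main() {
	tensors := []io.WriterTo{
		torchWriterTo{name: "blk.0.attn_q.weight"},
		safetensorWriterTo{name: "blk.0.attn_k.weight"},
	}

	for i, t := range tensors {
		// Same shape as the patch: pick the handler by concrete type, then store
		// the modified copy back, since the type assertion yields a value copy.
		switch wt := t.(type) {
		case torchWriterTo:
			wt.handler = func(io.Writer) error {
				fmt.Println("torch repack:", wt.name)
				return nil
			}
			tensors[i] = wt
		case safetensorWriterTo:
			wt.handler = func(io.Writer) error {
				fmt.Println("safetensors repack:", wt.name)
				return nil
			}
			tensors[i] = wt
		}
	}

	// Writing the tensors now routes through whichever handler was attached above.
	for _, t := range tensors {
		if _, err := t.WriteTo(os.Stdout); err != nil {
			fmt.Fprintln(os.Stderr, err)
		}
	}
}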