some changes for llama3

Patrick Devine 2024-04-18 16:00:20 -07:00 committed by Michael Yang
parent 2f81b3dce2
commit d88582dffd
2 changed files with 6 additions and 3 deletions


@@ -77,7 +77,8 @@ func GetModelFormat(dirname string) (ModelFormat, error) {
 		slog.Debug(fmt.Sprintf("file = %s", fn))
 		if strings.HasSuffix(fn, ".safetensors") {
 			return &SafetensorFormat{}, nil
-		} else if strings.HasSuffix(fn, ".bin") {
+		//} else if strings.HasSuffix(fn, ".bin") {
+		} else if strings.HasSuffix(fn, ".pth") {
 			slog.Debug("model is torch")
 			return &TorchFormat{}, nil
 		}
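
For context, Meta's Llama 3 torch checkpoints ship as consolidated.NN.pth files rather than pytorch_model-*.bin shards, so format detection now keys on the .pth suffix and the old .bin branch is commented out, meaning .bin-only directories are no longer picked up as torch models. A minimal standalone sketch of that suffix-based dispatch (detectFormat and the string return values are illustrative stand-ins, not the converter's actual ModelFormat API):

package main

import (
	"fmt"
	"strings"
)

// detectFormat mirrors the suffix check above: .safetensors wins first,
// .pth now selects the torch path, and .bin is no longer matched.
func detectFormat(filenames []string) (string, error) {
	for _, fn := range filenames {
		if strings.HasSuffix(fn, ".safetensors") {
			return "safetensors", nil
		} else if strings.HasSuffix(fn, ".pth") {
			return "torch", nil
		}
	}
	return "", fmt.Errorf("couldn't determine model format")
}

func main() {
	fmt.Println(detectFormat([]string{"params.json", "consolidated.00.pth"})) // torch <nil>
	fmt.Println(detectFormat([]string{"pytorch_model-00001-of-00002.bin"}))   // no longer treated as torch
}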


@@ -33,7 +33,8 @@ type TorchFormat struct{}
 func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
 	slog.Debug("getting torch tensors")
-	files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
+	//files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
+	files, err := filepath.Glob(filepath.Join(dirpath, "consolidated.*.pth"))
 	if err != nil {
 		slog.Error("didn't find any torch files")
 		return nil, err
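
One thing to keep in mind when reading the error path above: filepath.Glob returns an error only for a malformed pattern (filepath.ErrBadPattern); a directory that simply contains no consolidated.*.pth files yields an empty slice with a nil error, so the "didn't find any torch files" log would not fire in that case. A small sketch of the standard-library behavior (the directory path is hypothetical, and this is not a claim about how the converter handles the empty-match case):

package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	// No matching checkpoints: Glob returns an empty slice and a nil error,
	// so err alone doesn't signal "no torch files were found".
	files, err := filepath.Glob(filepath.Join("/tmp/empty-model-dir", "consolidated.*.pth"))
	fmt.Println(len(files), err) // 0 <nil>

	// The error path only fires for a syntactically invalid pattern.
	_, err = filepath.Glob("consolidated.[.pth")
	fmt.Println(err) // syntax error in pattern
}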
@@ -120,7 +121,7 @@ func getAltParams(dirpath string) (*Params, error) {
 		AttentionHeads int     `json:"n_heads"`
 		KeyValHeads    int     `json:"n_kv_heads"`
 		HiddenLayers   int     `json:"n_layers"`
-		RopeTheta      int     `json:"rope_theta"`
+		RopeTheta      float64 `json:"rope_theta"`
 		NormEPS        float64 `json:"norm_eps"`
 	}
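
The int-to-float64 change for RopeTheta matters because encoding/json refuses to unmarshal a JSON number written with a fractional part into a Go int, and Llama 3's params.json reportedly stores rope_theta that way (e.g. 500000.0); a float64 field accepts it. A minimal sketch of the failure mode, using hypothetical stand-in structs rather than the real params type:

package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical stand-ins for the struct above.
type intTheta struct {
	RopeTheta int `json:"rope_theta"`
}

type floatTheta struct {
	RopeTheta float64 `json:"rope_theta"`
}

func main() {
	raw := []byte(`{"rope_theta": 500000.0}`)

	var a intTheta
	fmt.Println(json.Unmarshal(raw, &a)) // error: cannot unmarshal number 500000.0 into ... of type int

	var b floatTheta
	fmt.Println(json.Unmarshal(raw, &b), b.RopeTheta) // <nil> 500000
}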
@@ -133,6 +134,7 @@ func getAltParams(dirpath string) (*Params, error) {
 	}
 	params := &Params{
+		Architectures:  []string{"LlamaForCausalLM"},
 		HiddenSize:     tparams.HiddenSize,
 		AttentionHeads: tparams.AttentionHeads,
 		KeyValHeads:    tparams.KeyValHeads,