some changes for llama3
This commit is contained in:
parent
2f81b3dce2
commit
d88582dffd
2 changed files with 6 additions and 3 deletions
|
@ -77,7 +77,8 @@ func GetModelFormat(dirname string) (ModelFormat, error) {
|
|||
slog.Debug(fmt.Sprintf("file = %s", fn))
|
||||
if strings.HasSuffix(fn, ".safetensors") {
|
||||
return &SafetensorFormat{}, nil
|
||||
} else if strings.HasSuffix(fn, ".bin") {
|
||||
//} else if strings.HasSuffix(fn, ".bin") {
|
||||
} else if strings.HasSuffix(fn, ".pth") {
|
||||
slog.Debug("model is torch")
|
||||
return &TorchFormat{}, nil
|
||||
}
|
||||
|
|
|
@ -33,7 +33,8 @@ type TorchFormat struct{}
|
|||
func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
slog.Debug("getting torch tensors")
|
||||
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
|
||||
//files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "consolidatedr.*.pth"))
|
||||
if err != nil {
|
||||
slog.Error("didn't find any torch files")
|
||||
return nil, err
|
||||
|
@ -120,7 +121,7 @@ func getAltParams(dirpath string) (*Params, error) {
|
|||
AttentionHeads int `json:"n_heads"`
|
||||
KeyValHeads int `json:"n_kv_heads"`
|
||||
HiddenLayers int `json:"n_layers"`
|
||||
RopeTheta int `json:"rope_theta"`
|
||||
RopeTheta float64 `json:"rope_theta"`
|
||||
NormEPS float64 `json:"norm_eps"`
|
||||
}
|
||||
|
||||
|
@ -133,6 +134,7 @@ func getAltParams(dirpath string) (*Params, error) {
|
|||
}
|
||||
|
||||
params := &Params{
|
||||
Architectures: []string{"LlamaForCausalLM"},
|
||||
HiddenSize: tparams.HiddenSize,
|
||||
AttentionHeads: tparams.AttentionHeads,
|
||||
KeyValHeads: tparams.KeyValHeads,
|
||||
|
|
Loading…
Reference in a new issue