// ollama/convert/reader_torch.go

package convert
import (
	"fmt"
	"io"
	"io/fs"

	"github.com/nlpodyssey/gopickle/pytorch"
	"github.com/nlpodyssey/gopickle/types"
)
func parseTorch(fsys fs.FS, ps ...string) ([]Tensor, error) {
2024-05-31 20:00:49 -07:00
var ts []Tensor
for _, p := range ps {
pt, err := pytorch.Load(p)
if err != nil {
return nil, err
}
for _, k := range pt.(*types.Dict).Keys() {
t := pt.(*types.Dict).MustGet(k)
var shape []uint64
for dim := range t.(*pytorch.Tensor).Size {
shape = append(shape, uint64(dim))
}
ts = append(ts, torch{
storage: t.(*pytorch.Tensor).Source,
tensorBase: &tensorBase{
name: k.(string),
shape: shape,
},
})
}
}
return ts, nil
}
// torch is a Tensor backed by a tensor read from a PyTorch checkpoint by
// parseTorch.
//
// NOTE(review): parseTorch copies only the tensor's backing storage (Source)
// into this struct, not the source tensor's storage offset or strides. A
// tensor that is a view into a larger shared storage therefore cannot be told
// apart from the whole storage — confirm this is acceptable before tensor data
// is written from here.
type torch struct {
	// storage is the pickled tensor's backing storage (pytorch.Tensor.Source).
	storage pytorch.StorageInterface
	// tensorBase is filled in by parseTorch with the tensor's name and shape.
	*tensorBase
}
// WriteTo is meant to write the tensor's raw data to w (its signature matches
// io.WriterTo).
//
// Writing tensor data out of a PyTorch storage is not implemented yet.
// Previously this method returned (0, nil): it reported success after writing
// nothing, so a caller could emit output whose tensor data was silently
// missing. It now returns an error naming the tensor so the gap surfaces at
// conversion time.
func (pt torch) WriteTo(w io.Writer) (int64, error) {
	return 0, fmt.Errorf("writing tensor %q: converting pytorch tensor data is not implemented", pt.name)
}