override ggml-metal if the file is different

This commit is contained in:
Michael Yang 2023-08-02 12:50:30 -07:00
parent 6fbb380076
commit 0e79e52ddd

View file

@ -1,6 +1,8 @@
package llama
import (
"bytes"
"crypto/sha256"
"errors"
"io"
"log"
@ -27,26 +29,51 @@ func initBackend() error {
}
metal := filepath.Join(filepath.Dir(exec), "ggml-metal.metal")
if _, err := os.Stat(metal); err != nil {
if !errors.Is(err, os.ErrNotExist) {
return err
}
fi, err := os.Stat(metal)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal"))
if fi != nil {
actual, err := os.Open(metal)
if err != nil {
return err
}
defer dst.Close()
src, err := fs.Open("ggml-metal.metal")
actualSum := sha256.New()
if _, err := io.Copy(actualSum, actual); err != nil {
return err
}
expect, err := fs.Open("ggml-metal.metal")
if err != nil {
return err
}
defer src.Close()
if _, err := io.Copy(dst, src); err != nil {
expectSum := sha256.New()
if _, err := io.Copy(expectSum, expect); err != nil {
return err
}
if bytes.Equal(actualSum.Sum(nil), expectSum.Sum(nil)) {
return nil
}
}
dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal"))
if err != nil {
return err
}
defer dst.Close()
src, err := fs.Open("ggml-metal.metal")
if err != nil {
return err
}
defer src.Close()
if _, err := io.Copy(dst, src); err != nil {
return err
}
return nil