Merge pull request #260 from jmorganca/embed-ggml-metal

override ggml-metal if the file is different
This commit is contained in:
Michael Yang 2023-08-02 13:01:46 -07:00 committed by GitHub
commit cc509a994e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,6 +1,8 @@
package llama package llama
import ( import (
"bytes"
"crypto/sha256"
"errors" "errors"
"io" "io"
"log" "log"
@ -27,26 +29,51 @@ func initBackend() error {
} }
metal := filepath.Join(filepath.Dir(exec), "ggml-metal.metal") metal := filepath.Join(filepath.Dir(exec), "ggml-metal.metal")
if _, err := os.Stat(metal); err != nil { fi, err := os.Stat(metal)
if !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
return err return err
} }
dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal")) if fi != nil {
actual, err := os.Open(metal)
if err != nil { if err != nil {
return err return err
} }
defer dst.Close()
src, err := fs.Open("ggml-metal.metal") actualSum := sha256.New()
if _, err := io.Copy(actualSum, actual); err != nil {
return err
}
expect, err := fs.Open("ggml-metal.metal")
if err != nil { if err != nil {
return err return err
} }
defer src.Close()
if _, err := io.Copy(dst, src); err != nil { expectSum := sha256.New()
if _, err := io.Copy(expectSum, expect); err != nil {
return err return err
} }
if bytes.Equal(actualSum.Sum(nil), expectSum.Sum(nil)) {
return nil
}
}
dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal"))
if err != nil {
return err
}
defer dst.Close()
src, err := fs.Open("ggml-metal.metal")
if err != nil {
return err
}
defer src.Close()
if _, err := io.Copy(dst, src); err != nil {
return err
} }
return nil return nil