2023-07-21 13:33:56 -07:00
|
|
|
package llm
|
2023-07-26 11:50:29 -07:00
|
|
|
|
|
|
|
import (
|
2023-08-02 12:50:30 -07:00
|
|
|
"bytes"
|
|
|
|
"crypto/sha256"
|
2023-07-26 11:50:29 -07:00
|
|
|
"errors"
|
|
|
|
"io"
|
|
|
|
"log"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
|
|
|
)
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
if err := initBackend(); err != nil {
|
|
|
|
log.Printf("WARNING: GPU could not be initialized correctly: %v", err)
|
|
|
|
log.Printf("WARNING: falling back to CPU")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func initBackend() error {
|
|
|
|
exec, err := os.Executable()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
exec, err = filepath.EvalSymlinks(exec)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
metal := filepath.Join(filepath.Dir(exec), "ggml-metal.metal")
|
2023-08-02 12:50:30 -07:00
|
|
|
fi, err := os.Stat(metal)
|
|
|
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if fi != nil {
|
|
|
|
actual, err := os.Open(metal)
|
|
|
|
if err != nil {
|
2023-07-26 11:50:29 -07:00
|
|
|
return err
|
|
|
|
}
|
2023-08-14 16:08:02 -07:00
|
|
|
defer actual.Close()
|
2023-07-26 11:50:29 -07:00
|
|
|
|
2023-08-02 12:50:30 -07:00
|
|
|
actualSum := sha256.New()
|
|
|
|
if _, err := io.Copy(actualSum, actual); err != nil {
|
2023-07-26 11:50:29 -07:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2023-08-02 12:50:30 -07:00
|
|
|
expect, err := fs.Open("ggml-metal.metal")
|
2023-07-26 11:50:29 -07:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2023-08-02 12:50:30 -07:00
|
|
|
expectSum := sha256.New()
|
|
|
|
if _, err := io.Copy(expectSum, expect); err != nil {
|
2023-07-26 11:50:29 -07:00
|
|
|
return err
|
|
|
|
}
|
2023-08-02 12:50:30 -07:00
|
|
|
|
|
|
|
if bytes.Equal(actualSum.Sum(nil), expectSum.Sum(nil)) {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
dst, err := os.Create(filepath.Join(filepath.Dir(exec), "ggml-metal.metal"))
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer dst.Close()
|
|
|
|
|
|
|
|
src, err := fs.Open("ggml-metal.metal")
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer src.Close()
|
|
|
|
|
|
|
|
if _, err := io.Copy(dst, src); err != nil {
|
|
|
|
return err
|
2023-07-26 11:50:29 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|