Always dynamically load the llm server library

This switches darwin to dynamic loading, and refactors the code now that
static linking of the library is no longer used on any platform.
Daniel Hiltgen 2024-01-09 20:29:58 -08:00
parent d88c527be3
commit 39928a42e8
23 changed files with 290 additions and 463 deletions
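To make the new flow concrete, here is a small, self-contained Go sketch of the selection-and-fallback loop this commit standardizes on (see newLlmServer and getDynLibs in the diffs below). loadFirstWorking and tryLoad are illustrative stand-ins for newDynExtServer, which dlopen()s the chosen library and initializes the server; the library names are hypothetical.

package main

import (
    "errors"
    "fmt"
    "log"
)

// loadFirstWorking models the selection-and-fallback loop: candidates is an
// ordered list of dynamic libraries (best match first, plain CPU last) and
// the first one that loads wins. tryLoad stands in for newDynExtServer.
func loadFirstWorking(candidates []string, tryLoad func(string) error) (string, error) {
    err := errors.New("unable to locate suitable llm library")
    for _, lib := range candidates {
        e := tryLoad(lib)
        if e == nil {
            return lib, nil
        }
        log.Printf("failed to load dynamic library %s: %s", lib, e)
        err = e
    }
    return "", err
}

func main() {
    // Hypothetical ordering for a CUDA machine whose driver turns out to be missing.
    libs := []string{"cuda/libext_server.so", "cpu_avx2/libext_server.so", "cpu/libext_server.so"}
    lib, err := loadFirstWorking(libs, func(lib string) error {
        if lib == "cuda/libext_server.so" {
            return errors.New("CUDA driver not present")
        }
        return nil
    })
    fmt.Println(lib, err) // cpu_avx2/libext_server.so <nil>
}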

go.mod (2 lines changed)

@@ -45,7 +45,7 @@ require (
 	golang.org/x/crypto v0.14.0
 	golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
 	golang.org/x/net v0.17.0 // indirect
-	golang.org/x/sys v0.13.0 // indirect
+	golang.org/x/sys v0.13.0
 	golang.org/x/term v0.13.0
 	golang.org/x/text v0.13.0 // indirect
 	google.golang.org/protobuf v1.30.0 // indirect

gpu/cpu_common.go (new file, 21 lines)

@@ -0,0 +1,21 @@
package gpu

import (
    "log"

    "golang.org/x/sys/cpu"
)

func GetCPUVariant() string {
    if cpu.X86.HasAVX2 {
        log.Printf("CPU has AVX2")
        return "avx2"
    }
    if cpu.X86.HasAVX {
        log.Printf("CPU has AVX")
        return "avx"
    }
    log.Printf("CPU does not have vector extensions")
    // fall back to the lowest common denominator (LCD) build
    return ""
}
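A small usage sketch (hypothetical, but the cpu/cpu_avx/cpu_avx2 naming matches the build directories added elsewhere in this commit; the library file name is illustrative): the variant string simply suffixes the CPU library name so the loader can prefer the most capable build the host supports.

package main

import (
    "fmt"

    "github.com/jmorganca/ollama/gpu"
)

func main() {
    // "", "avx", or "avx2" depending on the host CPU
    libName := "cpu"
    if v := gpu.GetCPUVariant(); v != "" {
        libName += "_" + v
    }
    fmt.Printf("would prefer %s/libext_server.so\n", libName)
}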

View file

@@ -32,8 +32,15 @@ func CheckVRAM() (int64, error) {
 func GetGPUInfo() GpuInfo {
 	mem, _ := getCPUMem()
 
+	if runtime.GOARCH == "amd64" {
+		return GpuInfo{
+			Library: "default",
+			Variant: GetCPUVariant(),
+			memInfo: mem,
+		}
+	}
 	return GpuInfo{
-		Library: "default",
+		Library: "metal",
 		memInfo: mem,
 	}
 }
@@ -45,12 +52,3 @@ func getCPUMem() (memInfo, error) {
 		DeviceCount: 0,
 	}, nil
 }
-
-func nativeInit() error {
-	return nil
-}
-
-func GetCPUVariant() string {
-	// We don't yet have CPU based builds for Darwin...
-	return ""
-}

View file

@@ -9,7 +9,7 @@ import (
 func TestBasicGetGPUInfo(t *testing.T) {
 	info := GetGPUInfo()
-	assert.Contains(t, "cuda rocm cpu default", info.Library)
+	assert.Contains(t, "cuda rocm cpu metal", info.Library)
 
 	switch runtime.GOOS {
 	case "darwin":
View file

@@ -1,4 +1,4 @@
-#include "dynamic_shim.h"
+#include "dyn_ext_server.h"
 
 #include <stdio.h>
 #include <string.h>
@@ -33,7 +33,7 @@ inline char *LOAD_ERR() {
 #define UNLOAD_LIBRARY(handle) dlclose(handle)
 #endif
 
-void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
+void dyn_init(const char *libPath, struct dynamic_llama_server *s,
               ext_server_resp_t *err) {
   int i = 0;
   struct lookup {
@@ -83,63 +83,63 @@ void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
   }
 }
 
-inline void dynamic_shim_llama_server_init(struct dynamic_llama_server s,
+inline void dyn_llama_server_init(struct dynamic_llama_server s,
                                   ext_server_params_t *sparams,
                                   ext_server_resp_t *err) {
   s.llama_server_init(sparams, err);
 }
 
-inline void dynamic_shim_llama_server_start(struct dynamic_llama_server s) {
+inline void dyn_llama_server_start(struct dynamic_llama_server s) {
   s.llama_server_start();
 }
 
-inline void dynamic_shim_llama_server_stop(struct dynamic_llama_server s) {
+inline void dyn_llama_server_stop(struct dynamic_llama_server s) {
   s.llama_server_stop();
 }
 
-inline void dynamic_shim_llama_server_completion(struct dynamic_llama_server s,
+inline void dyn_llama_server_completion(struct dynamic_llama_server s,
                                         const char *json_req,
                                         ext_server_resp_t *resp) {
   s.llama_server_completion(json_req, resp);
 }
 
-inline void dynamic_shim_llama_server_completion_next_result(
+inline void dyn_llama_server_completion_next_result(
     struct dynamic_llama_server s, const int task_id,
     ext_server_task_result_t *result) {
   s.llama_server_completion_next_result(task_id, result);
 }
 
-inline void dynamic_shim_llama_server_completion_cancel(
+inline void dyn_llama_server_completion_cancel(
    struct dynamic_llama_server s, const int task_id, ext_server_resp_t *err) {
   s.llama_server_completion_cancel(task_id, err);
 }
 
-inline void dynamic_shim_llama_server_release_task_result(
+inline void dyn_llama_server_release_task_result(
     struct dynamic_llama_server s, ext_server_task_result_t *result) {
   s.llama_server_release_task_result(result);
 }
 
-inline void dynamic_shim_llama_server_tokenize(struct dynamic_llama_server s,
+inline void dyn_llama_server_tokenize(struct dynamic_llama_server s,
                                       const char *json_req,
                                       char **json_resp,
                                       ext_server_resp_t *err) {
   s.llama_server_tokenize(json_req, json_resp, err);
 }
 
-inline void dynamic_shim_llama_server_detokenize(struct dynamic_llama_server s,
+inline void dyn_llama_server_detokenize(struct dynamic_llama_server s,
                                         const char *json_req,
                                         char **json_resp,
                                         ext_server_resp_t *err) {
   s.llama_server_detokenize(json_req, json_resp, err);
 }
 
-inline void dynamic_shim_llama_server_embedding(struct dynamic_llama_server s,
+inline void dyn_llama_server_embedding(struct dynamic_llama_server s,
                                        const char *json_req,
                                        char **json_resp,
                                        ext_server_resp_t *err) {
   s.llama_server_embedding(json_req, json_resp, err);
 }
 
-inline void dynamic_shim_llama_server_release_json_resp(
+inline void dyn_llama_server_release_json_resp(
     struct dynamic_llama_server s, char **json_resp) {
   s.llama_server_release_json_resp(json_resp);
 }

View file

@@ -10,25 +10,25 @@ package llm
 #cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 #cgo darwin LDFLAGS: -lc++ -framework Accelerate
 #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
-#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libcommon.a
-#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libext_server.a
-#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libllama.a
-#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libggml_static.a
 #cgo linux CFLAGS: -D_GNU_SOURCE
 #cgo linux LDFLAGS: -lrt -ldl -lstdc++ -lm
 #cgo linux windows LDFLAGS: -lpthread
 #include <stdlib.h>
-#include "ext_server.h"
+#include "dyn_ext_server.h"
 */
 import "C"
 import (
 	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
 	"log"
+	"os"
+	"path/filepath"
+	"runtime"
 	"strings"
 	"sync"
 	"time"
@@ -37,21 +37,9 @@ import (
 	"github.com/jmorganca/ollama/api"
 )
 
-// TODO switch Linux to always be dynamic
-// If that works out, then look at the impact of doing the same for Mac
-type extServer interface {
-	LLM
-	llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t)
-	llama_server_start()
-	llama_server_stop()
-	llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t)
-	llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t)
-	llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t)
-	llama_server_release_task_result(result *C.ext_server_task_result_t)
-	llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
-	llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
-	llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
-	llama_server_release_json_resp(json_resp **C.char)
+type dynExtServer struct {
+	s       C.struct_dynamic_llama_server
+	options api.Options
 }
 
 // Note: current implementation does not support concurrent instantiations
@@ -76,11 +64,30 @@ func extServerResponseToErr(resp C.ext_server_resp_t) error {
 	return fmt.Errorf(C.GoString(resp.msg))
 }
 
-func newExtServer(server extServer, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
+// Note: current implementation does not support concurrent instantiations
+var llm *dynExtServer
+
+func newDynExtServer(library, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
 	if !mutex.TryLock() {
 		log.Printf("concurrent llm servers not yet supported, waiting for prior server to complete")
 		mutex.Lock()
 	}
+	updatePath(filepath.Dir(library))
+	libPath := C.CString(library)
+	defer C.free(unsafe.Pointer(libPath))
+	resp := newExtServerResp(128)
+	defer freeExtServerResp(resp)
+	var srv C.struct_dynamic_llama_server
+	C.dyn_init(libPath, &srv, &resp)
+	if resp.id < 0 {
+		mutex.Unlock()
+		return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
+	}
+	llm = &dynExtServer{
+		s:       srv,
+		options: opts,
+	}
+	log.Printf("Loading Dynamic llm server: %s", library)
 
 	var sparams C.ext_server_params_t
 	sparams.model = C.CString(model)
@@ -129,20 +136,20 @@ func newExtServer(server extServer, model string, adapters, projectors []string,
 	sparams.n_threads = C.uint(opts.NumThread)
 
-	log.Printf("Initializing internal llama server")
-	resp := newExtServerResp(128)
-	defer freeExtServerResp(resp)
-	server.llama_server_init(&sparams, &resp)
-	if resp.id < 0 {
-		return nil, extServerResponseToErr(resp)
+	log.Printf("Initializing llama server")
+	initResp := newExtServerResp(128)
+	defer freeExtServerResp(initResp)
+	C.dyn_llama_server_init(llm.s, &sparams, &initResp)
+	if initResp.id < 0 {
+		return nil, extServerResponseToErr(initResp)
 	}
 
-	log.Printf("Starting internal llama main loop")
-	server.llama_server_start()
-	return server, nil
+	log.Printf("Starting llama main loop")
+	C.dyn_llama_server_start(llm.s)
+	return llm, nil
 }
 
-func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
+func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
 	var imageData []ImageData
@@ -200,7 +207,7 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
 	req := C.CString(buffer.String())
 	defer C.free(unsafe.Pointer(req))
 
-	llm.llama_server_completion(req, &resp)
+	C.dyn_llama_server_completion(llm.s, req, &resp)
 	if resp.id < 0 {
 		return extServerResponseToErr(resp)
 	}
@@ -211,7 +218,7 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
 		select {
 		case <-ctx.Done():
 			// This handles the request cancellation
-			llm.llama_server_completion_cancel(resp.id, &resp)
+			C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
 			if resp.id < 0 {
 				return extServerResponseToErr(resp)
 			} else {
@@ -219,13 +226,13 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
 			}
 		default:
 			var result C.ext_server_task_result_t
-			llm.llama_server_completion_next_result(resp.id, &result)
+			C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
 			json_resp := C.GoString(result.json_resp)
-			llm.llama_server_release_task_result(&result)
+			C.dyn_llama_server_release_task_result(llm.s, &result)
 			var p prediction
 			if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
-				llm.llama_server_completion_cancel(resp.id, &resp)
+				C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
 				if resp.id < 0 {
 					return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
 				} else {
@@ -266,7 +273,7 @@ func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(Pr
 	return fmt.Errorf("max retries exceeded")
 }
 
-func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
+func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
 	if err != nil {
 		return nil, fmt.Errorf("marshaling encode data: %w", err)
@@ -276,11 +283,11 @@ func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
 	var json_resp *C.char
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	llm.llama_server_tokenize(req, &json_resp, &resp)
+	C.dyn_llama_server_tokenize(llm.s, req, &json_resp, &resp)
 	if resp.id < 0 {
 		return nil, extServerResponseToErr(resp)
 	}
-	defer llm.llama_server_release_json_resp(&json_resp)
+	defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
 
 	var encoded TokenizeResponse
 	if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &encoded); err2 != nil {
@@ -290,7 +297,7 @@ func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
 	return encoded.Tokens, err
 }
 
-func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
+func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
 	if len(tokens) == 0 {
 		return "", nil
 	}
@@ -304,11 +311,11 @@ func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
 	var json_resp *C.char
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	llm.llama_server_detokenize(req, &json_resp, &resp)
+	C.dyn_llama_server_detokenize(llm.s, req, &json_resp, &resp)
 	if resp.id < 0 {
 		return "", extServerResponseToErr(resp)
 	}
-	defer llm.llama_server_release_json_resp(&json_resp)
+	defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
 
 	var decoded DetokenizeResponse
 	if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &decoded); err2 != nil {
@@ -318,7 +325,7 @@ func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
 	return decoded.Content, err
 }
 
-func embedding(llm extServer, ctx context.Context, input string) ([]float64, error) {
+func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
 	data, err := json.Marshal(TokenizeRequest{Content: input})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
@@ -329,11 +336,11 @@ func embedding(llm extServer, ctx context.Context, input string) ([]float64, err
 	var json_resp *C.char
 	resp := newExtServerResp(128)
 	defer freeExtServerResp(resp)
-	llm.llama_server_embedding(req, &json_resp, &resp)
+	C.dyn_llama_server_embedding(llm.s, req, &json_resp, &resp)
 	if resp.id < 0 {
 		return nil, extServerResponseToErr(resp)
 	}
-	defer llm.llama_server_release_json_resp(&json_resp)
+	defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
 
 	var embedding EmbeddingResponse
 	if err := json.Unmarshal([]byte(C.GoString(json_resp)), &embedding); err != nil {
@@ -343,7 +350,38 @@ func embedding(llm extServer, ctx context.Context, input string) ([]float64, err
 	return embedding.Embedding, nil
 }
 
-func close(llm extServer) {
-	llm.llama_server_stop()
+func (llm *dynExtServer) Close() {
+	C.dyn_llama_server_stop(llm.s)
 	mutex.Unlock()
 }
+
+func updatePath(dir string) {
+	if runtime.GOOS == "windows" {
+		tmpDir := filepath.Dir(dir)
+		pathComponents := strings.Split(os.Getenv("PATH"), ";")
+		i := 0
+		for _, comp := range pathComponents {
+			if strings.EqualFold(comp, dir) {
+				return
+			}
+			// Remove any other prior paths to our temp dir
+			if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
+				pathComponents[i] = comp
+				i++
+			}
+		}
+		newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
+		log.Printf("Updating PATH to %s", newPath)
+		os.Setenv("PATH", newPath)
+	} else {
+		pathComponents := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
+		for _, comp := range pathComponents {
+			if comp == dir {
+				return
+			}
+		}
+		newPath := strings.Join(append([]string{dir}, pathComponents...), ":")
+		log.Printf("Updating LD_LIBRARY_PATH to %s", newPath)
+		os.Setenv("LD_LIBRARY_PATH", newPath)
+	}
+}

View file

@@ -27,46 +27,46 @@ struct dynamic_llama_server {
   void (*llama_server_release_json_resp)(char **json_resp);
 };
 
-void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
+void dyn_init(const char *libPath, struct dynamic_llama_server *s,
               ext_server_resp_t *err);
 
 // No good way to call C function pointers from Go so inline the indirection
-void dynamic_shim_llama_server_init(struct dynamic_llama_server s,
+void dyn_llama_server_init(struct dynamic_llama_server s,
                            ext_server_params_t *sparams,
                            ext_server_resp_t *err);
-void dynamic_shim_llama_server_start(struct dynamic_llama_server s);
-void dynamic_shim_llama_server_stop(struct dynamic_llama_server s);
-void dynamic_shim_llama_server_completion(struct dynamic_llama_server s,
+void dyn_llama_server_start(struct dynamic_llama_server s);
+void dyn_llama_server_stop(struct dynamic_llama_server s);
+void dyn_llama_server_completion(struct dynamic_llama_server s,
                                  const char *json_req,
                                  ext_server_resp_t *resp);
-void dynamic_shim_llama_server_completion_next_result(
+void dyn_llama_server_completion_next_result(
     struct dynamic_llama_server s, const int task_id,
     ext_server_task_result_t *result);
-void dynamic_shim_llama_server_completion_cancel(struct dynamic_llama_server s,
+void dyn_llama_server_completion_cancel(struct dynamic_llama_server s,
                                         const int task_id,
                                         ext_server_resp_t *err);
-void dynamic_shim_llama_server_release_task_result(
+void dyn_llama_server_release_task_result(
     struct dynamic_llama_server s, ext_server_task_result_t *result);
-void dynamic_shim_llama_server_tokenize(struct dynamic_llama_server s,
+void dyn_llama_server_tokenize(struct dynamic_llama_server s,
                                const char *json_req, char **json_resp,
                                ext_server_resp_t *err);
-void dynamic_shim_llama_server_detokenize(struct dynamic_llama_server s,
+void dyn_llama_server_detokenize(struct dynamic_llama_server s,
                                  const char *json_req,
                                  char **json_resp,
                                  ext_server_resp_t *err);
-void dynamic_shim_llama_server_embedding(struct dynamic_llama_server s,
+void dyn_llama_server_embedding(struct dynamic_llama_server s,
                                 const char *json_req, char **json_resp,
                                 ext_server_resp_t *err);
-void dynamic_shim_llama_server_release_json_resp(struct dynamic_llama_server s,
+void dyn_llama_server_release_json_resp(struct dynamic_llama_server s,
                                         char **json_resp);
 
 #ifdef __cplusplus
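The comment about function pointers refers to a cgo limitation: Go code cannot call through a C function pointer directly, so each call is routed through a plain C wrapper like the ones declared above. A minimal, self-contained Go illustration of that pattern (toy names, not code from this commit):

package main

/*
// ops_t mimics (in miniature) struct dynamic_llama_server: a struct of
// function pointers filled in at load time.
typedef struct {
  int (*add)(int, int);
} ops_t;

static int add_impl(int a, int b) { return a + b; }

static ops_t make_ops() {
  ops_t o;
  o.add = add_impl;
  return o;
}

// call_add plays the role of the dyn_* wrappers: Go cannot invoke o.add
// itself, so this C helper performs the indirect call.
static int call_add(ops_t o, int a, int b) { return o.add(a, b); }
*/
import "C"

import "fmt"

func main() {
    ops := C.make_ops()                     // resolved "function table"
    fmt.Println(int(C.call_add(ops, 2, 3))) // prints 5
}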

View file

@ -1,17 +0,0 @@
//go:build !darwin
package llm
import (
"fmt"
"github.com/jmorganca/ollama/api"
)
func newDefaultExtServer(model string, adapters, projectors []string, opts api.Options) (extServer, error) {
// On windows and linux we always load the llama.cpp libraries dynamically to avoid startup DLL dependencies
// This ensures we can update the PATH at runtime to get everything loaded
// This should never happen as we'll always try to load one or more cpu dynamic libaries before hitting default
return nil, fmt.Errorf("no available default llm library")
}

View file

@ -1,82 +0,0 @@
//go:build darwin
package llm
/*
#include <stdlib.h>
#include "ext_server.h"
*/
import "C"
import (
"context"
"github.com/jmorganca/ollama/api"
)
// TODO - explore shifting Darwin to a dynamic loading pattern for consistency with Linux and Windows
type llamaExtServer struct {
api.Options
}
func (llm *llamaExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
C.llama_server_init(sparams, err)
}
func (llm *llamaExtServer) llama_server_start() {
C.llama_server_start()
}
func (llm *llamaExtServer) llama_server_stop() {
C.llama_server_stop()
}
func (llm *llamaExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
C.llama_server_completion(json_req, resp)
}
func (llm *llamaExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
C.llama_server_completion_next_result(task_id, resp)
}
func (llm *llamaExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
C.llama_server_completion_cancel(task_id, err)
}
func (llm *llamaExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
C.llama_server_release_task_result(result)
}
func (llm *llamaExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_tokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_detokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_embedding(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_release_json_resp(json_resp **C.char) {
C.llama_server_release_json_resp(json_resp)
}
func newDefaultExtServer(model string, adapters, projectors []string, opts api.Options) (extServer, error) {
server := &llamaExtServer{opts}
return newExtServer(server, model, adapters, projectors, opts)
}
func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
return predict(ctx, llm, pred, fn)
}
func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
return encode(llm, ctx, prompt)
}
func (llm *llamaExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
return decode(llm, ctx, tokens)
}
func (llm *llamaExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
return embedding(llm, ctx, input)
}
func (llm *llamaExtServer) Close() {
close(llm)
}

View file

@@ -29,4 +29,16 @@ git_module_setup
 apply_patches
 build
 install
+gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
+    -Wl,-force_load ${BUILD_DIR}/lib/libext_server.a \
+    ${BUILD_DIR}/lib/libcommon.a \
+    ${BUILD_DIR}/lib/libllama.a \
+    ${BUILD_DIR}/lib/libggml_static.a \
+    -lpthread -ldl -lm -lc++ \
+    -framework Accelerate \
+    -framework Foundation \
+    -framework Metal \
+    -framework MetalKit \
+    -framework MetalPerformanceShaders
 cleanup

View file

@@ -104,12 +104,6 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
         build
         install
         link_server_lib
-        gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
-            -Wl,--whole-archive \
-            ${BUILD_DIR}/lib/libext_server.a \
-            -Wl,--no-whole-archive \
-            ${BUILD_DIR}/lib/libcommon.a \
-            ${BUILD_DIR}/lib/libllama.a
     fi
 else
     echo "Skipping CPU generation step as requested"

View file

@@ -4,7 +4,7 @@ $ErrorActionPreference = "Stop"
 
 function init_vars {
     $script:llamacppDir = "../llama.cpp"
-    $script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-DLLAMA_F16C=off", "-DLLAMA_FMA=off", "-DLLAMA_AVX512=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX=on", "-A","x64")
+    $script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-A","x64")
     $script:cmakeTargets = @("ggml", "ggml_static", "llama", "build_info", "common", "ext_server_shared", "llava_static")
     if ($env:CGO_CFLAGS -contains "-g") {
         $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on")
@@ -63,16 +63,36 @@ init_vars
 git_module_setup
 apply_patches
 
-# first build CPU based
-$script:buildDir="${script:llamacppDir}/build/windows/cpu"
+# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
+# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
+# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
+# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
+
+$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on", "-DLLAMA_NATIVE=off")
+
+$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+$script:buildDir="${script:llamacppDir}/build/windows/cpu"
+write-host "Building LCD CPU"
+build
+install
+
+$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
+$script:buildDir="${script:llamacppDir}/build/windows/cpu_avx"
+write-host "Building AVX CPU"
+build
+install
+
+$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
+$script:buildDir="${script:llamacppDir}/build/windows/cpu_avx2"
+write-host "Building AVX2 CPU"
 build
 install
 
 # Then build cuda as a dynamically loaded library
+# TODO figure out how to detect cuda version
 init_vars
 $script:buildDir="${script:llamacppDir}/build/windows/cuda"
-$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON")
+$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on")
 build
 install

View file

@@ -138,33 +138,30 @@ func Init(workdir string) error {
 	return nativeInit(workdir)
 }
 
-func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
-	shims := getShims(gpuInfo)
+func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
+	dynLibs := getDynLibs(gpuInfo)
 
 	// Check to see if the user has requested a specific library instead of auto-detecting
 	demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
 	if demandLib != "" {
-		libPath := availableShims[demandLib]
+		libPath := availableDynLibs[demandLib]
 		if libPath == "" {
 			log.Printf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib)
 		} else {
 			log.Printf("Loading OLLAMA_LLM_LIBRARY=%s", demandLib)
-			shims = []string{libPath}
+			dynLibs = []string{libPath}
 		}
 	}
 
-	for _, shim := range shims {
-		// TODO - only applies on Darwin (switch to fully dynamic there too...)
-		if shim == "default" {
-			break
-		}
-		srv, err := newDynamicShimExtServer(shim, model, adapters, projectors, opts)
+	err2 := fmt.Errorf("unable to locate suitable llm library")
+	for _, dynLib := range dynLibs {
+		srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts)
 		if err == nil {
 			return srv, nil
 		}
-		log.Printf("Failed to load dynamic library %s %s", shim, err)
+		log.Printf("Failed to load dynamic library %s %s", dynLib, err)
+		err2 = err
 	}
 
-	return newDefaultExtServer(model, adapters, projectors, opts)
+	return nil, err2
 }

View file

@@ -18,42 +18,42 @@ import (
 // Libraries names may contain an optional variant separated by '_'
 // For example, "rocm_v6" and "rocm_v5" or "cpu" and "cpu_avx2"
 // Any library without a variant is the lowest common denominator
-var availableShims = map[string]string{}
+var availableDynLibs = map[string]string{}
 
 const pathComponentCount = 6
 
-// getShims returns an ordered list of shims to try, starting with the best
-func getShims(gpuInfo gpu.GpuInfo) []string {
+// getDynLibs returns an ordered list of LLM libraries to try, starting with the best
+func getDynLibs(gpuInfo gpu.GpuInfo) []string {
 	// Short circuit if we know we're using the default built-in (darwin only)
 	if gpuInfo.Library == "default" {
 		return []string{"default"}
 	}
 	exactMatch := ""
-	shims := []string{}
-	altShims := []string{}
+	dynLibs := []string{}
+	altDynLibs := []string{}
 	requested := gpuInfo.Library
 	if gpuInfo.Variant != "" {
 		requested += "_" + gpuInfo.Variant
 	}
 	// Try to find an exact match
-	for cmp := range availableShims {
+	for cmp := range availableDynLibs {
 		if requested == cmp {
 			exactMatch = cmp
-			shims = []string{availableShims[cmp]}
+			dynLibs = []string{availableDynLibs[cmp]}
 			break
 		}
 	}
 	// Then for GPUs load alternates and sort the list for consistent load ordering
 	if gpuInfo.Library != "cpu" {
-		for cmp := range availableShims {
+		for cmp := range availableDynLibs {
 			if gpuInfo.Library == strings.Split(cmp, "_")[0] && cmp != exactMatch {
-				altShims = append(altShims, cmp)
+				altDynLibs = append(altDynLibs, cmp)
 			}
 		}
-		slices.Sort(altShims)
-		for _, altShim := range altShims {
-			shims = append(shims, availableShims[altShim])
+		slices.Sort(altDynLibs)
+		for _, altDynLib := range altDynLibs {
+			dynLibs = append(dynLibs, availableDynLibs[altDynLib])
 		}
 	}
@@ -65,27 +65,27 @@ func getShims(gpuInfo gpu.GpuInfo) []string {
 		// Attempting to run the wrong CPU instructions will panic the
 		// process
 		if variant != "" {
-			for cmp := range availableShims {
+			for cmp := range availableDynLibs {
 				if cmp == "cpu_"+variant {
-					shims = append(shims, availableShims[cmp])
+					dynLibs = append(dynLibs, availableDynLibs[cmp])
 					break
 				}
 			}
 		} else {
-			shims = append(shims, availableShims["cpu"])
+			dynLibs = append(dynLibs, availableDynLibs["cpu"])
 		}
 	}
 
 	// Finaly, if we didn't find any matches, LCD CPU FTW
-	if len(shims) == 0 {
-		shims = []string{availableShims["cpu"]}
+	if len(dynLibs) == 0 {
+		dynLibs = []string{availableDynLibs["cpu"]}
 	}
-	return shims
+	return dynLibs
 }
 
-func rocmShimPresent() bool {
-	for shimName := range availableShims {
-		if strings.HasPrefix(shimName, "rocm") {
+func rocmDynLibPresent() bool {
+	for dynLibName := range availableDynLibs {
+		if strings.HasPrefix(dynLibName, "rocm") {
 			return true
 		}
 	}
@@ -104,7 +104,6 @@ func nativeInit(workdir string) error {
 			return err
 		}
 		os.Setenv("GGML_METAL_PATH_RESOURCES", workdir)
-		return nil
 	}
 
 	libs, err := extractDynamicLibs(workdir, "llama.cpp/build/*/*/lib/*")
@@ -118,7 +117,7 @@ func nativeInit(workdir string) error {
 	for _, lib := range libs {
 		// The last dir component is the variant name
 		variant := filepath.Base(filepath.Dir(lib))
-		availableShims[variant] = lib
+		availableDynLibs[variant] = lib
 	}
 
 	if err := verifyDriverAccess(); err != nil {
@@ -126,9 +125,9 @@ func nativeInit(workdir string) error {
 	}
 
 	// Report which dynamic libraries we have loaded to assist troubleshooting
-	variants := make([]string, len(availableShims))
+	variants := make([]string, len(availableDynLibs))
 	i := 0
-	for variant := range availableShims {
+	for variant := range availableDynLibs {
 		variants[i] = variant
 		i++
 	}
@@ -226,7 +225,7 @@ func verifyDriverAccess() error {
 		return nil
 	}
 	// Only check ROCm access if we have the dynamic lib loaded
-	if rocmShimPresent() {
+	if rocmDynLibPresent() {
 		// Verify we have permissions - either running as root, or we have group access to the driver
 		fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
 		if err != nil {

llm/payload_darwin.go (new file, 8 lines)

@@ -0,0 +1,8 @@
package llm

import (
    "embed"
)

//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/*/lib/*.so
var libEmbed embed.FS
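These embed directives (darwin here, linux and windows below) bundle the per-variant shared libraries into the binary; they must be written out to a working directory before they can be dlopen()ed, after which nativeInit (in payload_common.go above) records one entry per variant directory. A rough sketch of that extraction step, where extractLibs is a hypothetical, simplified stand-in for the extractDynamicLibs helper referenced earlier:

package llm

import (
    "io/fs"
    "os"
    "path"
    "path/filepath"
)

// extractLibs copies every embedded library matching pattern into
// workdir/<variant>/, where <variant> is the build directory name
// ("metal", "cpu_avx2", ...), and returns the extracted paths.
func extractLibs(payloads fs.FS, pattern, workdir string) ([]string, error) {
    matches, err := fs.Glob(payloads, pattern)
    if err != nil {
        return nil, err
    }
    var out []string
    for _, m := range matches {
        // embedded paths look like llama.cpp/build/<os>/<variant>/lib/<file>
        variant := path.Base(path.Dir(path.Dir(m)))
        dest := filepath.Join(workdir, variant, path.Base(m))
        data, err := fs.ReadFile(payloads, m)
        if err != nil {
            return nil, err
        }
        if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil {
            return nil, err
        }
        if err := os.WriteFile(dest, data, 0o755); err != nil {
            return nil, err
        }
        out = append(out, dest)
    }
    return out, nil
}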

llm/payload_linux.go (new file, 8 lines)

@@ -0,0 +1,8 @@
package llm

import (
    "embed"
)

//go:embed llama.cpp/build/linux/*/lib/*.so
var libEmbed embed.FS

llm/payload_test.go (new file, 54 lines)

@@ -0,0 +1,54 @@
package llm

import (
    "testing"

    "github.com/jmorganca/ollama/gpu"
    "github.com/stretchr/testify/assert"
)

func TestGetDynLibs(t *testing.T) {
    availableDynLibs = map[string]string{
        "cpu": "X_cpu",
    }
    assert.Equal(t, false, rocmDynLibPresent())
    res := getDynLibs(gpu.GpuInfo{Library: "cpu"})
    assert.Len(t, res, 1)
    assert.Equal(t, availableDynLibs["cpu"], res[0])

    availableDynLibs = map[string]string{
        "rocm_v5": "X_rocm_v5",
        "rocm_v6": "X_rocm_v6",
        "cpu":     "X_cpu",
    }
    assert.Equal(t, true, rocmDynLibPresent())
    res = getDynLibs(gpu.GpuInfo{Library: "rocm"})
    assert.Len(t, res, 3)
    assert.Equal(t, availableDynLibs["rocm_v5"], res[0])
    assert.Equal(t, availableDynLibs["rocm_v6"], res[1])
    assert.Equal(t, availableDynLibs["cpu"], res[2])

    res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
    assert.Len(t, res, 3)
    assert.Equal(t, availableDynLibs["rocm_v6"], res[0])
    assert.Equal(t, availableDynLibs["rocm_v5"], res[1])
    assert.Equal(t, availableDynLibs["cpu"], res[2])

    res = getDynLibs(gpu.GpuInfo{Library: "cuda"})
    assert.Len(t, res, 1)
    assert.Equal(t, availableDynLibs["cpu"], res[0])

    res = getDynLibs(gpu.GpuInfo{Library: "default"})
    assert.Len(t, res, 1)
    assert.Equal(t, "default", res[0])

    availableDynLibs = map[string]string{
        "rocm": "X_rocm_v5",
        "cpu":  "X_cpu",
    }
    assert.Equal(t, true, rocmDynLibPresent())
    res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
    assert.Len(t, res, 2)
    assert.Equal(t, availableDynLibs["rocm"], res[0])
    assert.Equal(t, availableDynLibs["cpu"], res[1])
}

llm/payload_windows.go (new file, 8 lines)

@@ -0,0 +1,8 @@
package llm

import (
    "embed"
)

//go:embed llama.cpp/build/windows/*/lib/*.dll
var libEmbed embed.FS

View file

@ -1,16 +0,0 @@
package llm
import (
"embed"
"fmt"
"github.com/jmorganca/ollama/api"
)
//go:embed llama.cpp/ggml-metal.metal
var libEmbed embed.FS
func newDynamicShimExtServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
// should never happen...
return nil, fmt.Errorf("Dynamic library loading not supported on Mac")
}

View file

@ -1,107 +0,0 @@
//go:build !darwin
package llm
/*
#include <stdlib.h>
#include "dynamic_shim.h"
*/
import "C"
import (
"context"
"fmt"
"log"
"path/filepath"
"sync"
"unsafe"
"github.com/jmorganca/ollama/api"
)
type shimExtServer struct {
s C.struct_dynamic_llama_server
options api.Options
}
// Note: current implementation does not support concurrent instantiations
var shimMutex sync.Mutex
var llm *shimExtServer
func (llm *shimExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_init(llm.s, sparams, err)
}
func (llm *shimExtServer) llama_server_start() {
C.dynamic_shim_llama_server_start(llm.s)
}
func (llm *shimExtServer) llama_server_stop() {
C.dynamic_shim_llama_server_stop(llm.s)
}
func (llm *shimExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_completion(llm.s, json_req, resp)
}
func (llm *shimExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
C.dynamic_shim_llama_server_completion_next_result(llm.s, task_id, resp)
}
func (llm *shimExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_completion_cancel(llm.s, task_id, err)
}
func (llm *shimExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
C.dynamic_shim_llama_server_release_task_result(llm.s, result)
}
func (llm *shimExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_tokenize(llm.s, json_req, json_resp, err)
}
func (llm *shimExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_detokenize(llm.s, json_req, json_resp, err)
}
func (llm *shimExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_embedding(llm.s, json_req, json_resp, err)
}
func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) {
C.dynamic_shim_llama_server_release_json_resp(llm.s, json_resp)
}
func newDynamicShimExtServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
shimMutex.Lock()
defer shimMutex.Unlock()
updatePath(filepath.Dir(library))
libPath := C.CString(library)
defer C.free(unsafe.Pointer(libPath))
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
var srv C.struct_dynamic_llama_server
C.dynamic_shim_init(libPath, &srv, &resp)
if resp.id < 0 {
return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
}
llm = &shimExtServer{
s: srv,
options: opts,
}
log.Printf("Loading Dynamic Shim llm server: %s", library)
return newExtServer(llm, model, adapters, projectors, opts)
}
func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
return predict(ctx, llm, pred, fn)
}
func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
return encode(llm, ctx, prompt)
}
func (llm *shimExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
return decode(llm, ctx, tokens)
}
func (llm *shimExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
return embedding(llm, ctx, input)
}
func (llm *shimExtServer) Close() {
close(llm)
}

View file

@ -1,23 +0,0 @@
package llm
import (
"embed"
"log"
"os"
"strings"
)
//go:embed llama.cpp/build/*/*/lib/*.so
var libEmbed embed.FS
func updatePath(dir string) {
pathComponents := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
for _, comp := range pathComponents {
if comp == dir {
return
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ":")
log.Printf("Updating LD_LIBRARY_PATH to %s", newPath)
os.Setenv("LD_LIBRARY_PATH", newPath)
}

View file

@ -1,31 +0,0 @@
package llm
import (
"embed"
"log"
"os"
"path/filepath"
"strings"
)
//go:embed llama.cpp/build/windows/*/lib/*.dll
var libEmbed embed.FS
func updatePath(dir string) {
tmpDir := filepath.Dir(dir)
pathComponents := strings.Split(os.Getenv("PATH"), ";")
i := 0
for _, comp := range pathComponents {
if strings.EqualFold(comp, dir) {
return
}
// Remove any other prior paths to our temp dir
if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
pathComponents[i] = comp
i++
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
log.Printf("Updating PATH to %s", newPath)
os.Setenv("PATH", newPath)
}

View file

@ -1,54 +0,0 @@
package llm
import (
"testing"
"github.com/jmorganca/ollama/gpu"
"github.com/stretchr/testify/assert"
)
func TestGetShims(t *testing.T) {
availableShims = map[string]string{
"cpu": "X_cpu",
}
assert.Equal(t, false, rocmShimPresent())
res := getShims(gpu.GpuInfo{Library: "cpu"})
assert.Len(t, res, 1)
assert.Equal(t, availableShims["cpu"], res[0])
availableShims = map[string]string{
"rocm_v5": "X_rocm_v5",
"rocm_v6": "X_rocm_v6",
"cpu": "X_cpu",
}
assert.Equal(t, true, rocmShimPresent())
res = getShims(gpu.GpuInfo{Library: "rocm"})
assert.Len(t, res, 3)
assert.Equal(t, availableShims["rocm_v5"], res[0])
assert.Equal(t, availableShims["rocm_v6"], res[1])
assert.Equal(t, availableShims["cpu"], res[2])
res = getShims(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
assert.Len(t, res, 3)
assert.Equal(t, availableShims["rocm_v6"], res[0])
assert.Equal(t, availableShims["rocm_v5"], res[1])
assert.Equal(t, availableShims["cpu"], res[2])
res = getShims(gpu.GpuInfo{Library: "cuda"})
assert.Len(t, res, 1)
assert.Equal(t, availableShims["cpu"], res[0])
res = getShims(gpu.GpuInfo{Library: "default"})
assert.Len(t, res, 1)
assert.Equal(t, "default", res[0])
availableShims = map[string]string{
"rocm": "X_rocm_v5",
"cpu": "X_cpu",
}
assert.Equal(t, true, rocmShimPresent())
res = getShims(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
assert.Len(t, res, 2)
assert.Equal(t, availableShims["rocm"], res[0])
assert.Equal(t, availableShims["cpu"], res[1])
}