diff --git a/llama/llama.go b/llama/llama.go index 86bf0c54..f7032624 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -128,6 +128,11 @@ func New(model string, opts api.Options) (*LLM, error) { C.llama_backend_init(C.bool(llm.UseNUMA)) + // TODO: GQA == 8 suggests 70B model which doesn't support metal + if llm.NumGQA == 8 { + llm.NumGPU = 0 + } + params := C.llama_context_default_params() params.seed = C.uint(llm.Seed) params.n_ctx = C.int(llm.NumCtx)