diff --git a/llm/llama.go b/llm/llama.go index 4481e97d..4eab751d 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -339,6 +339,7 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers "--model", model, "--ctx-size", fmt.Sprintf("%d", opts.NumCtx), "--batch-size", fmt.Sprintf("%d", opts.NumBatch), + "--main-gpu", fmt.Sprintf("%d", opts.MainGPU), "--n-gpu-layers", fmt.Sprintf("%d", numGPU), "--embedding", } @@ -544,6 +545,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, "stream": true, "n_predict": llm.NumPredict, "n_keep": llm.NumKeep, + "main_gpu": llm.MainGPU, "temperature": llm.Temperature, "top_k": llm.TopK, "top_p": llm.TopP,