Offload layers to GPU based on new model size estimates (#1850)
* select layers based on estimated model memory usage
* always account for scratch vram
* don't load +1 layers
* better estimation for graph alloc
* Update gpu/gpu_darwin.go
  Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
* Update llm/llm.go
  Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
* Update llm/llm.go
* add overhead for cuda memory
* Update llm/llm.go
  Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
* fix build error on linux
* address comments

---------

Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
parent 7e8f7c8358
commit 08f1e18965

10 changed files with 161 additions and 154 deletions
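Taken together, these changes replace the old file-size heuristic with a budget built from GGUF metadata: estimate the fp16 KV cache and the graph scratch allocation, compare the total against free VRAM, and only then choose a layer count. The sketch below paraphrases the new selection logic in llm.New (shown in full in the llm/llm.go hunk further down); it omits the macOS and cpu-library special cases and is illustrative rather than a verbatim excerpt.

package llm

// pickGPULayers is a condensed paraphrase of the layer-selection logic this
// commit adds to llm.New; the names follow the diff below, but the helper
// itself is illustrative and not part of the change.
func pickGPULayers(requiredModel, requiredKv, available, numLayers, gqa int64, requestedGPU int) int {
    requiredAlloc := gqa * requiredKv / 7 // scratch/graph guess: ~1/7 of the KV cache times GQA
    requiredTotal := requiredModel + requiredKv + requiredAlloc

    numGPU := requestedGPU
    if numGPU == -1 {
        numGPU = int(numLayers) + 1 // default: try to offload every layer
    }
    if numGPU <= 0 {
        return 0
    }
    if requiredTotal <= available {
        return numGPU // everything fits, no splitting needed
    }
    if requiredAlloc > available || requiredKv > available {
        return 0 // scratch or KV cache alone overflow VRAM: stay on CPU
    }
    available -= requiredAlloc // scratch stays resident regardless of layer count
    bytesPerLayer := (requiredModel + requiredKv) / numLayers
    if layers := int(available / bytesPerLayer); layers < numGPU {
        numGPU = layers
    }
    return numGPU
}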
gpu/gpu.go (35 changed lines)

@@ -16,8 +16,6 @@ import (
     "runtime"
     "sync"
     "unsafe"
-
-    "github.com/jmorganca/ollama/api"
 )
 
 type handles struct {
@@ -133,31 +131,14 @@ func getCPUMem() (memInfo, error) {
 func CheckVRAM() (int64, error) {
     gpuInfo := GetGPUInfo()
     if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
-        return int64(gpuInfo.FreeMemory), nil
+        // allocate 384MiB for llama.cpp overhead (outside of model)
+        overhead := uint64(384 * 1024 * 1024)
+        if gpuInfo.FreeMemory <= overhead {
+            return 0, nil
+        }
+
+        return int64(gpuInfo.FreeMemory - overhead), nil
     }
 
     return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
 }
-
-func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
-    if opts.NumGPU != -1 {
-        return opts.NumGPU
-    }
-    info := GetGPUInfo()
-    if info.Library == "cpu" || info.Library == "default" {
-        return 0
-    }
-
-    /*
-        Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
-        We can store the model weights and the kv cache in vram,
-        to enable kv chache vram storage add two additional layers to the number of layers retrieved from the model file.
-    */
-    bytesPerLayer := uint64(fileSizeBytes / numLayer)
-
-    // 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
-    layers := int(info.FreeMemory/bytesPerLayer) * 3 / 4
-
-    log.Printf("%d MB VRAM available, loading up to %d %s GPU layers out of %d", info.FreeMemory/(1024*1024), layers, info.Library, numLayer)
-
-    return layers
-}
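The 384 MiB subtracted above is a fixed allowance for llama.cpp allocations that sit outside the model itself. A small hypothetical illustration of what callers of CheckVRAM now see (the 8 GiB figure is made up, not from this commit):

package main

import "fmt"

func main() {
    // Hypothetical numbers, not taken from this commit: a card reporting 8 GiB free.
    free := uint64(8 * 1024 * 1024 * 1024)
    overhead := uint64(384 * 1024 * 1024) // fixed llama.cpp overhead reserved by CheckVRAM
    fmt.Println((free - overhead) / (1024 * 1024)) // 7808 MiB left for model layers
}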
@@ -6,18 +6,31 @@ import "C"
 import (
     "runtime"
 
-    "github.com/jmorganca/ollama/api"
+    "github.com/pbnjay/memory"
 )
 
 // CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
 func CheckVRAM() (int64, error) {
-    // TODO - assume metal, and return free memory?
-    return 0, nil
+    if runtime.GOARCH == "amd64" {
+        // gpu not supported, this may not be metal
+        return 0, nil
+    }
+
+    // on macOS, there's already buffer for available vram (see below) so just return the total
+    systemMemory := int64(memory.TotalMemory())
+
+    // macOS limits how much memory is available to the GPU based on the amount of system memory
+    // TODO: handle case where iogpu.wired_limit_mb is set to a higher value
+    if systemMemory <= 36*1024*1024*1024 {
+        systemMemory = systemMemory * 2 / 3
+    } else {
+        systemMemory = systemMemory * 3 / 4
+    }
+
+    return systemMemory, nil
 }
 
 func GetGPUInfo() GpuInfo {
-    // TODO - Metal vs. x86 macs...
     mem, _ := getCPUMem()
     return GpuInfo{
         Library: "default",
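For a concrete sense of the new macOS cap, the helper below repeats the rule with made-up machine sizes; it is a sketch, not code from the commit. A 16 GiB Mac falls under the 36 GiB threshold and reports two thirds of RAM, while a 64 GiB Mac reports three quarters.

package main

import "fmt"

// vramBudget mirrors the rule added to CheckVRAM on darwin: report 2/3 of
// system memory up to 36 GiB, 3/4 above that. Illustrative only.
func vramBudget(systemMemory int64) int64 {
    if systemMemory <= 36*1024*1024*1024 {
        return systemMemory * 2 / 3
    }
    return systemMemory * 3 / 4
}

func main() {
    const gib = int64(1024 * 1024 * 1024)
    fmt.Println(vramBudget(16*gib) / gib) // 10 (about 10.7 GiB)
    fmt.Println(vramBudget(64*gib) / gib) // 48
}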
@@ -32,19 +45,6 @@ func getCPUMem() (memInfo, error) {
     }, nil
 }
 
-func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
-    if opts.NumGPU != -1 {
-        return opts.NumGPU
-    }
-
-    // metal only supported on arm64
-    if runtime.GOARCH == "arm64" {
-        return 1
-    }
-
-    return 0
-}
-
 func nativeInit() error {
     return nil
 }
@@ -35,14 +35,12 @@ import (
     "encoding/json"
     "fmt"
     "log"
-    "os"
     "strings"
     "sync"
     "time"
     "unsafe"
 
     "github.com/jmorganca/ollama/api"
-    "github.com/jmorganca/ollama/gpu"
 )
 
 type extServer interface {
@@ -82,25 +80,20 @@ func extServerResponseToErr(resp C.ext_server_resp_t) error {
     return fmt.Errorf(C.GoString(resp.msg))
 }
 
-func newExtServer(server extServer, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
+func newExtServer(server extServer, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
     if !mutex.TryLock() {
         log.Printf("concurrent llm servers not yet supported, waiting for prior server to complete")
         mutex.Lock()
     }
-    fileInfo, err := os.Stat(model)
-    if err != nil {
-        return nil, err
-    }
     var sparams C.ext_server_params_t
     sparams.model = C.CString(model)
     defer C.free(unsafe.Pointer(sparams.model))
 
-    numGPU := gpu.NumGPU(numLayers, fileInfo.Size(), opts)
-
     sparams.embedding = true
     sparams.n_ctx = C.uint(opts.NumCtx)
     sparams.n_batch = C.uint(opts.NumBatch)
-    sparams.n_gpu_layers = C.int(numGPU)
+    sparams.n_gpu_layers = C.int(opts.NumGPU)
     sparams.main_gpu = C.int(opts.MainGPU)
     sparams.n_parallel = 1 // TODO - wire up concurrency
@@ -54,9 +54,9 @@ func (llm *llamaExtServer) llama_server_release_json_resp(json_resp **C.char) {
     C.llama_server_release_json_resp(json_resp)
 }
 
-func newDefaultExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
+func newDefaultExtServer(model string, adapters, projectors []string, opts api.Options) (extServer, error) {
     server := &llamaExtServer{opts}
-    return newExtServer(server, model, adapters, projectors, numLayers, opts)
+    return newExtServer(server, model, adapters, projectors, opts)
 }
 
 func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
@@ -78,7 +78,11 @@ type model interface {
     ModelFamily() string
     ModelType() string
     FileType() string
-    NumLayers() int64
+    NumLayers() uint32
+    NumGQA() uint32
+    NumEmbed() uint32
+    NumHead() uint32
+    NumHeadKv() uint32
 }
 
 type container interface {
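The widened interface exists so the loader can size the fp16 KV cache from metadata alone, before any weights are loaded. A minimal sketch of such a consumer, assuming a value that satisfies these accessors; the helper and the 7B-like numbers are illustrative, not part of the commit:

package main

import "fmt"

// kvParams is the subset of the widened model interface that a KV-cache
// estimate needs.
type kvParams interface {
    NumLayers() uint32
    NumEmbed() uint32
    NumHead() uint32
    NumHeadKv() uint32
}

// kvCacheBytes mirrors the fp16 KV-cache formula used later in llm/llm.go:
// 2 bytes per element * 2 (key and value) * n_ctx * n_layer * n_embd * n_head_kv / n_head.
func kvCacheBytes(m kvParams, numCtx int64) int64 {
    return 2 * 2 * numCtx * int64(m.NumLayers()) * int64(m.NumEmbed()) *
        int64(m.NumHeadKv()) / int64(m.NumHead())
}

// llama7B carries hypothetical 7B-class metadata for illustration only.
type llama7B struct{}

func (llama7B) NumLayers() uint32 { return 32 }
func (llama7B) NumEmbed() uint32  { return 4096 }
func (llama7B) NumHead() uint32   { return 32 }
func (llama7B) NumHeadKv() uint32 { return 32 }

func main() {
    fmt.Println(kvCacheBytes(llama7B{}, 2048)) // 1073741824 bytes, i.e. 1 GiB at n_ctx=2048
}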
llm/gguf.go (41 changed lines)

@@ -272,14 +272,49 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
     return nil
 }
 
-func (llm *ggufModel) NumLayers() int64 {
+func (llm *ggufModel) NumLayers() uint32 {
     value, exists := llm.kv[fmt.Sprintf("%s.block_count", llm.ModelFamily())]
     if !exists {
         return 0
     }
 
-    v := value.(uint32)
-    return int64(v)
+    return value.(uint32)
+}
+
+func (llm *ggufModel) NumHead() uint32 {
+    value, exists := llm.kv[fmt.Sprintf("%s.attention.head_count", llm.ModelFamily())]
+    if !exists {
+        return 0
+    }
+
+    return value.(uint32)
+}
+
+func (llm *ggufModel) NumEmbed() uint32 {
+    value, exists := llm.kv[fmt.Sprintf("%s.embedding_length", llm.ModelFamily())]
+    if !exists {
+        return 0
+    }
+
+    return value.(uint32)
+}
+
+func (llm *ggufModel) NumHeadKv() uint32 {
+    value, exists := llm.kv[fmt.Sprintf("%s.attention.head_count_kv", llm.ModelFamily())]
+    if !exists {
+        return 0
+    }
+
+    return value.(uint32)
+}
+
+func (llm *ggufModel) NumGQA() uint32 {
+    numHeadKv := llm.NumHeadKv()
+    if numHeadKv == 0 {
+        return 0
+    }
+
+    return llm.NumHead() / numHeadKv
 }
 
 func (llm ggufModel) readU8(r io.Reader) uint8 {
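NumGQA is simply head_count divided by head_count_kv. A small sketch with made-up values: a grouped-query model with 64 heads and 8 KV heads yields 8, while a classic multi-head model with equal counts yields 1.

package main

import "fmt"

// numGQA mirrors ggufModel.NumGQA: the ratio of attention heads to KV heads,
// with 0 meaning the metadata was missing.
func numGQA(numHead, numHeadKv uint32) uint32 {
    if numHeadKv == 0 {
        return 0
    }
    return numHead / numHeadKv
}

func main() {
    fmt.Println(numGQA(64, 8))  // 8: grouped-query attention, hypothetical 70B-class model
    fmt.Println(numGQA(32, 32)) // 1: classic multi-head attention, hypothetical 7B-class model
}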
llm/llama.go (59 changed lines)

@@ -8,7 +8,6 @@ import (
     "fmt"
     "os"
     "os/exec"
-    "sync"
     "time"
 
     "github.com/jmorganca/ollama/api"
@@ -43,68 +42,10 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
 ws ::= ([ \t\n] ws)?
 `
 
-type llamaModel struct {
-    hyperparameters llamaHyperparameters
-}
-
-func (llm *llamaModel) ModelFamily() string {
-    return "llama"
-}
-
-func llamaModelType(numLayer uint32) string {
-    switch numLayer {
-    case 26:
-        return "3B"
-    case 32:
-        return "7B"
-    case 40:
-        return "13B"
-    case 48:
-        return "34B"
-    case 60:
-        return "30B"
-    case 80:
-        return "65B"
-    default:
-        return "unknown"
-    }
-}
-
-func (llm *llamaModel) ModelType() string {
-    return llamaModelType(llm.hyperparameters.NumLayer)
-}
-
-func (llm *llamaModel) FileType() string {
-    return fileType(llm.hyperparameters.FileType)
-}
-
-func (llm *llamaModel) NumLayers() int64 {
-    return int64(llm.hyperparameters.NumLayer)
-}
-
-type llamaHyperparameters struct {
-    // NumVocab is the size of the model's vocabulary.
-    NumVocab uint32
-
-    // NumEmbd is the size of the model's embedding layer.
-    NumEmbd uint32
-    NumMult uint32
-    NumHead uint32
-
-    // NumLayer is the number of layers in the model.
-    NumLayer uint32
-    NumRot   uint32
-
-    // FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
-    FileType uint32
-}
-
 type Running struct {
     Port          int
     Cmd           *exec.Cmd
     Cancel        context.CancelFunc
-    exitOnce      sync.Once
-    exitCh        chan error // channel to receive the exit status of the subprocess
     *StatusWriter // captures error messages from the llama runner process
 }
llm/llm.go (115 changed lines)

@@ -7,10 +7,7 @@ import (
     "os"
     "runtime"
 
-    "github.com/pbnjay/memory"
-
     "github.com/jmorganca/ollama/api"
-    "github.com/jmorganca/ollama/format"
     "github.com/jmorganca/ollama/gpu"
 )
 
@@ -40,32 +37,89 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
         return nil, err
     }
 
-    if runtime.GOOS == "darwin" {
-        var requiredMemory int64
-        var f16Multiplier int64 = 2
-
-        switch ggml.ModelType() {
-        case "3B", "7B":
-            requiredMemory = 8 * format.GigaByte
-        case "13B":
-            requiredMemory = 16 * format.GigaByte
-        case "30B", "34B", "40B":
-            requiredMemory = 32 * format.GigaByte
-        case "47B":
-            requiredMemory = 48 * format.GigaByte
-        case "65B", "70B":
-            requiredMemory = 64 * format.GigaByte
-        case "180B":
-            requiredMemory = 128 * format.GigaByte
-            f16Multiplier = 4
-        }
-
-        systemMemory := int64(memory.TotalMemory())
-
-        if ggml.FileType() == "F16" && requiredMemory*f16Multiplier > systemMemory {
-            return nil, fmt.Errorf("F16 model requires at least %s of memory", format.HumanBytes(requiredMemory))
-        } else if requiredMemory > systemMemory {
-            return nil, fmt.Errorf("model requires at least %s of memory", format.HumanBytes(requiredMemory))
-        }
-    }
+    if opts.NumCtx < 4 {
+        opts.NumCtx = 4
+    }
 
+    fmt.Println("size", ggml.Size)
+    fmt.Println("filetype", ggml.FileType())
+    fmt.Println("architecture", ggml.ModelFamily())
+    fmt.Println("type", ggml.ModelType())
+    fmt.Println("name", ggml.Name())
+    fmt.Println("embd", ggml.NumEmbed())
+    fmt.Println("head", ggml.NumHead())
+    fmt.Println("head_kv", ggml.NumHeadKv())
+    fmt.Println("gqa", ggml.NumGQA())
+
+    available, _ := gpu.CheckVRAM()
+
+    // For now assume filesize = model size
+    // TODO: use actual model size
+    requiredModel := ggml.Size
+
+    // fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
+    requiredKv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())
+
+    // this amount is the overhead + tensors in memory
+    // TODO: get this from the llama.cpp's graph calcluations instead of
+    // guessing it's ~1/7th of the kv cache times gqa
+    requiredAlloc := int64(ggml.NumGQA()) * requiredKv / 7
+
+    requiredTotal := requiredModel + requiredKv + requiredAlloc
+
+    log.Println("system memory bytes:", available)
+    log.Println("required model bytes:", requiredModel)
+    log.Println("required kv bytes:", requiredKv)
+    log.Println("required alloc bytes:", requiredAlloc)
+    log.Println("required total bytes:", requiredTotal)
+
+    info := gpu.GetGPUInfo()
+    library := info.Library
+
+    if opts.NumGPU == -1 {
+        // default to offloading all layers
+        opts.NumGPU = int(ggml.NumLayers()) + 1
+    }
+
+    // decide how many layers to put on the GPU
+    if opts.NumGPU > 0 {
+        switch runtime.GOOS {
+        case "darwin":
+            if requiredTotal > available {
+                log.Println("not enough vram available, falling back to CPU only")
+                opts.NumGPU = 0
+            }
+        default:
+            if library == "cpu" || library == "default" {
+                opts.NumGPU = 0
+                break
+            }
+
+            // no offloading required
+            if requiredTotal <= available {
+                break
+            }
+
+            // This handles two cases:
+            // 1. overhead + tensors are always loaded into scratch memory even with num_gpu 0
+            // 2. it seems llama.cpp always tries to allocate the entire kv cache (even if later split into layers) into vram or crashes
+            if requiredAlloc > available || requiredKv > available {
+                log.Printf("not enough vram available, falling back to CPU only")
+                library = "cpu"
+                opts.NumGPU = 0
+                break
+            }
+
+            available -= requiredAlloc
+
+            // fill remaining vram with layers
+            log.Println("splitting", available, "of available memory bytes into layers")
+            bytesPerLayer := int64((requiredModel + requiredKv) / int64(ggml.NumLayers()))
+            log.Println("bytes per layer:", bytesPerLayer)
+            layers := available / bytesPerLayer
+            if layers < int64(opts.NumGPU) {
+                opts.NumGPU = int(layers)
+            }
+        }
+    }
 
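To see how this budget plays out, the snippet below runs the same arithmetic with made-up numbers for a 7B-class q4_0 model against a 4 GiB VRAM budget; every figure is illustrative, and the real code additionally falls back to CPU when the scratch or KV estimates alone exceed available VRAM.

package main

import "fmt"

func main() {
    const mib = int64(1024 * 1024)

    // Hypothetical 7B-class model, all numbers illustrative rather than measured:
    requiredModel := 3800 * mib // q4_0 weights, approximated by file size
    requiredKv := 1024 * mib    // fp16 KV cache at n_ctx=2048 (see the formula in the diff)
    gqa := int64(1)             // head_count == head_count_kv for this hypothetical model
    requiredAlloc := gqa * requiredKv / 7 // graph/scratch guess used by the new code
    requiredTotal := requiredModel + requiredKv + requiredAlloc

    available := 4096 * mib // e.g. what CheckVRAM reports after its own overhead deduction
    numLayers := int64(32)
    numGPU := numLayers + 1 // default: try to offload everything

    if requiredTotal > available { // full offload does not fit, split by layers
        available -= requiredAlloc // scratch stays resident regardless of layer count
        bytesPerLayer := (requiredModel + requiredKv) / numLayers
        if layers := available / bytesPerLayer; layers < numGPU {
            numGPU = layers
        }
    }
    fmt.Println(numGPU) // 26 of the 33 requested layers end up on the GPU in this example
}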
@@ -73,7 +127,7 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
         opts.RopeFrequencyBase = 0.0
         opts.RopeFrequencyScale = 0.0
     gpuInfo := gpu.GetGPUInfo()
-    return newLlmServer(gpuInfo.Library, model, adapters, projectors, ggml.NumLayers(), opts)
+    return newLlmServer(gpuInfo.Library, model, adapters, projectors, opts)
 }
 
 // Give any native cgo implementations an opportunity to initialize
@@ -81,9 +135,9 @@ func Init(workdir string) error {
     return nativeInit(workdir)
 }
 
-func newLlmServer(library, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
+func newLlmServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
     if _, libPresent := AvailableShims[library]; libPresent && library != "default" {
-        srv, err := newDynamicShimExtServer(AvailableShims[library], model, adapters, projectors, numLayers, opts)
+        srv, err := newDynamicShimExtServer(AvailableShims[library], model, adapters, projectors, opts)
         if err == nil {
             return srv, nil
         }
@@ -91,6 +145,5 @@ func newLlmServer(library, model string, adapters, projectors []string, numLayer
         // TODO - update some state to indicate we were unable to load the GPU library for future "info" ux
     }
 
-    return newDefaultExtServer(model, adapters, projectors, numLayers, opts)
-
+    return newDefaultExtServer(model, adapters, projectors, opts)
 }
@@ -16,7 +16,7 @@ import (
 //go:embed llama.cpp/ggml-metal.metal
 var libEmbed embed.FS
 
-func newDynamicShimExtServer(library, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
+func newDynamicShimExtServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
     // should never happen...
     return nil, fmt.Errorf("Dynamic library loading not supported on Mac")
 }
@@ -72,7 +72,7 @@ func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) {
     C.dynamic_shim_llama_server_release_json_resp(llm.s, json_resp)
 }
 
-func newDynamicShimExtServer(library, model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) {
+func newDynamicShimExtServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
     shimMutex.Lock()
     defer shimMutex.Unlock()
     updatePath(filepath.Dir(library))
@@ -90,7 +90,7 @@ func newDynamicShimExtServer(library, model string, adapters, projectors []strin
         options: opts,
     }
     log.Printf("Loading Dynamic Shim llm server: %s", library)
-    return newExtServer(llm, model, adapters, projectors, numLayers, opts)
+    return newExtServer(llm, model, adapters, projectors, opts)
 }
 
 func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {