// Source file: ollama/llm/llm.go (152 lines, 3.8 KiB, Go).
2023-07-21 13:33:56 -07:00
package llm
import (
"context"
"fmt"
"log"
2023-07-21 13:33:56 -07:00
"os"
"runtime"
2023-07-21 13:33:56 -07:00
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/gpu"
2023-07-21 13:33:56 -07:00
)
type LLM interface {
2023-12-05 14:57:33 -05:00
Predict(context.Context, PredictOpts, func(PredictResult)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)
2023-07-21 13:33:56 -07:00
Close()
}
2023-11-30 10:30:23 -08:00
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
2023-07-21 13:33:56 -07:00
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
}
2023-08-14 16:08:02 -07:00
defer f.Close()
2023-07-21 13:33:56 -07:00
2023-09-07 13:55:37 -04:00
ggml, err := DecodeGGML(f)
2023-07-21 13:33:56 -07:00
if err != nil {
return nil, err
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
vram, _ := gpu.CheckVRAM()
size := ggml.Size
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())
2023-10-12 10:36:23 -07:00
// this amount is the overhead + tensors in memory
2024-01-09 09:45:42 -08:00
// TODO: get this from the llama.cpp's graph calculations instead of
2024-01-08 21:32:44 -05:00
// estimating it's 1/6 * kv_cache_size * num_gqa
graph := int64(ggml.NumGQA()) * kv / 6
info := gpu.GetGPUInfo()
switch runtime.GOOS {
case "darwin":
if opts.NumGPU == 0 {
break
}
if size+kv+graph > vram {
log.Println("not enough vram available, falling back to CPU only")
info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
opts.NumGPU = 0
break
}
opts.NumGPU = 1
default:
if info.Library == "cpu" {
log.Println("GPU not available, falling back to CPU")
opts.NumGPU = 0
break
}
// don't use GPU at all if no layers are loaded
if opts.NumGPU == 0 {
info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
break
}
// user-defined GPU count
if opts.NumGPU != -1 {
break
}
// the "main" GPU needs the most memory and determines the limit
// of how many layers can be loaded. It needs to fit:
// 1. the full compute graph allocation for all devices (graph)
// 2. the proportional kv cache for all devices (kv * % layers)
// 3. the proportional model (size * % layers / # devices)
// This estimates the number of layers
maxlayers := int64(ggml.NumLayers()) + 1
devices := int64(info.DeviceCount)
avg := vram / devices
layers := maxlayers * (avg - graph) / (kv + size/devices)
if layers > maxlayers {
layers = maxlayers
}
// 1 + 2 must fit on the main gpu
min := graph + kv*layers/maxlayers
if layers <= 0 || min > avg {
log.Printf("not enough vram available, falling back to CPU only")
info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
opts.NumGPU = 0
break
2023-10-13 14:41:51 -07:00
}
opts.NumGPU = int(layers)
}
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
return newLlmServer(info, model, adapters, projectors, opts)
}
// Give any native cgo implementations an opportunity to initialize
func Init(workdir string) error {
return nativeInit(workdir)
2023-07-21 13:33:56 -07:00
}
// newLlmServer tries each candidate dynamic library in order and returns the
// first server that newDynExtServer creates without error.
//
// Candidates come from getDynLibs(gpuInfo). If the OLLAMA_LLM_LIBRARY
// environment variable names an entry with a non-empty path in
// availableDynLibs, that single path replaces the candidate list. An unknown
// name is logged and the auto-detected list is kept.
//
// When every candidate fails, the error from the last attempt is returned.
// When there are no candidates at all, a generic "unable to locate" error is
// returned.
func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
	candidates := getDynLibs(gpuInfo)

	// An explicit library request takes precedence over auto-detection.
	if requested := os.Getenv("OLLAMA_LLM_LIBRARY"); requested != "" {
		if override := availableDynLibs[requested]; override != "" {
			log.Printf("Loading OLLAMA_LLM_LIBRARY=%s", requested)
			candidates = []string{override}
		} else {
			log.Printf("Invalid OLLAMA_LLM_LIBRARY %s - not found", requested)
		}
	}

	// Reported only if the loop below never runs; otherwise it is replaced by
	// the most recent load failure.
	lastErr := fmt.Errorf("unable to locate suitable llm library")
	for _, lib := range candidates {
		server, err := newDynExtServer(lib, model, adapters, projectors, opts)
		if err != nil {
			log.Printf("Failed to load dynamic library %s %s", lib, err)
			lastErr = err
			continue
		}
		return server, nil
	}

	return nil, lastErr
}