2023-07-21 20:33:56 +00:00
package llm
import (
2023-08-30 20:35:03 +00:00
"context"
2024-01-11 22:43:16 +00:00
"fmt"
2024-01-18 18:52:01 +00:00
"log/slog"
2023-07-21 20:33:56 +00:00
"os"
2023-10-05 16:53:47 +00:00
"runtime"
2023-07-21 20:33:56 +00:00
"github.com/jmorganca/ollama/api"
2023-11-29 19:00:37 +00:00
"github.com/jmorganca/ollama/gpu"
2023-07-21 20:33:56 +00:00
)
type LLM interface {
2023-12-05 19:57:33 +00:00
Predict ( context . Context , PredictOpts , func ( PredictResult ) ) error
2023-08-30 20:35:03 +00:00
Embedding ( context . Context , string ) ( [ ] float64 , error )
Encode ( context . Context , string ) ( [ ] int , error )
Decode ( context . Context , [ ] int ) ( string , error )
2023-07-21 20:33:56 +00:00
Close ( )
}
2023-11-30 18:30:23 +00:00
func New ( workDir , model string , adapters , projectors [ ] string , opts api . Options ) ( LLM , error ) {
2023-07-21 20:33:56 +00:00
if _ , err := os . Stat ( model ) ; err != nil {
return nil , err
}
f , err := os . Open ( model )
if err != nil {
return nil , err
}
2023-08-14 23:08:02 +00:00
defer f . Close ( )
2023-07-21 20:33:56 +00:00
2023-09-07 17:55:37 +00:00
ggml , err := DecodeGGML ( f )
2023-07-21 20:33:56 +00:00
if err != nil {
return nil , err
}
2024-01-12 22:54:01 +00:00
if opts . NumCtx > int ( ggml . NumCtx ( ) ) {
2024-01-18 18:52:01 +00:00
slog . Warn ( fmt . Sprintf ( "requested context length is greater than model's max context length (%d > %d), using %d instead" , opts . NumCtx , ggml . NumCtx ( ) , ggml . NumCtx ( ) ) )
2024-01-12 22:54:01 +00:00
opts . NumCtx = int ( ggml . NumCtx ( ) )
}
2024-01-08 21:42:00 +00:00
if opts . NumCtx < 4 {
opts . NumCtx = 4
}
2024-01-11 06:45:31 +00:00
vram , _ := gpu . CheckVRAM ( )
size := ggml . Size
2024-01-08 21:42:00 +00:00
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
2024-01-11 06:45:31 +00:00
kv := 2 * 2 * int64 ( opts . NumCtx ) * int64 ( ggml . NumLayers ( ) ) * int64 ( ggml . NumEmbed ( ) ) * int64 ( ggml . NumHeadKv ( ) ) / int64 ( ggml . NumHead ( ) )
2023-10-12 17:36:23 +00:00
2024-01-08 21:42:00 +00:00
// this amount is the overhead + tensors in memory
2024-01-09 17:45:42 +00:00
// TODO: get this from the llama.cpp's graph calculations instead of
2024-01-09 02:32:44 +00:00
// estimating it's 1/6 * kv_cache_size * num_gqa
2024-01-11 06:45:31 +00:00
graph := int64 ( ggml . NumGQA ( ) ) * kv / 6
2024-01-08 21:42:00 +00:00
info := gpu . GetGPUInfo ( )
2024-01-11 06:45:31 +00:00
switch runtime . GOOS {
case "darwin" :
if opts . NumGPU == 0 {
break
}
2024-01-08 21:42:00 +00:00
2024-01-11 06:45:31 +00:00
if size + kv + graph > vram {
2024-01-18 18:52:01 +00:00
slog . Info ( "not enough vram available, falling back to CPU only" )
2024-01-11 22:43:16 +00:00
info . Library = "cpu"
info . Variant = gpu . GetCPUVariant ( )
2024-01-11 06:45:31 +00:00
opts . NumGPU = 0
break
}
2024-01-23 01:40:06 +00:00
// TODO: implement layer splitting on macOS
opts . NumGPU = 999
2024-01-11 06:45:31 +00:00
default :
2024-01-11 22:43:16 +00:00
if info . Library == "cpu" {
2024-01-18 18:52:01 +00:00
slog . Info ( "GPU not available, falling back to CPU" )
2024-01-11 06:45:31 +00:00
opts . NumGPU = 0
break
}
// don't use GPU at all if no layers are loaded
if opts . NumGPU == 0 {
2024-01-11 22:43:16 +00:00
info . Library = "cpu"
info . Variant = gpu . GetCPUVariant ( )
2024-01-11 06:45:31 +00:00
break
}
// user-defined GPU count
if opts . NumGPU != - 1 {
break
}
// the "main" GPU needs the most memory and determines the limit
// of how many layers can be loaded. It needs to fit:
// 1. the full compute graph allocation for all devices (graph)
// 2. the proportional kv cache for all devices (kv * % layers)
// 3. the proportional model (size * % layers / # devices)
// This estimates the number of layers
maxlayers := int64 ( ggml . NumLayers ( ) ) + 1
devices := int64 ( info . DeviceCount )
avg := vram / devices
layers := maxlayers * ( avg - graph ) / ( kv + size / devices )
if layers > maxlayers {
layers = maxlayers
}
2024-01-08 21:42:00 +00:00
2024-01-11 06:45:31 +00:00
// 1 + 2 must fit on the main gpu
min := graph + kv * layers / maxlayers
if layers <= 0 || min > avg {
2024-01-18 18:52:01 +00:00
slog . Info ( "not enough vram available, falling back to CPU only" )
2024-01-11 22:43:16 +00:00
info . Library = "cpu"
info . Variant = gpu . GetCPUVariant ( )
2024-01-11 06:45:31 +00:00
opts . NumGPU = 0
break
2023-10-13 21:41:51 +00:00
}
2024-01-11 06:45:31 +00:00
opts . NumGPU = int ( layers )
2023-08-03 22:47:36 +00:00
}
2023-11-24 18:58:09 +00:00
opts . RopeFrequencyBase = 0.0
opts . RopeFrequencyScale = 0.0
2024-01-11 22:43:16 +00:00
return newLlmServer ( info , model , adapters , projectors , opts )
2023-11-29 19:00:37 +00:00
}
// Give any native cgo implementations an opportunity to initialize
func Init ( workdir string ) error {
return nativeInit ( workdir )
2023-07-21 20:33:56 +00:00
}
2023-12-20 18:36:01 +00:00
2024-01-10 04:29:58 +00:00
func newLlmServer ( gpuInfo gpu . GpuInfo , model string , adapters , projectors [ ] string , opts api . Options ) ( LLM , error ) {
dynLibs := getDynLibs ( gpuInfo )
2024-01-07 23:48:05 +00:00
// Check to see if the user has requested a specific library instead of auto-detecting
demandLib := os . Getenv ( "OLLAMA_LLM_LIBRARY" )
if demandLib != "" {
2024-01-10 04:29:58 +00:00
libPath := availableDynLibs [ demandLib ]
2024-01-07 23:48:05 +00:00
if libPath == "" {
2024-01-18 18:52:01 +00:00
slog . Info ( fmt . Sprintf ( "Invalid OLLAMA_LLM_LIBRARY %s - not found" , demandLib ) )
2024-01-07 23:48:05 +00:00
} else {
2024-01-18 18:52:01 +00:00
slog . Info ( fmt . Sprintf ( "Loading OLLAMA_LLM_LIBRARY=%s" , demandLib ) )
2024-01-10 04:29:58 +00:00
dynLibs = [ ] string { libPath }
2024-01-07 23:48:05 +00:00
}
}
2024-01-10 04:29:58 +00:00
err2 := fmt . Errorf ( "unable to locate suitable llm library" )
for _ , dynLib := range dynLibs {
srv , err := newDynExtServer ( dynLib , model , adapters , projectors , opts )
2023-12-20 18:36:01 +00:00
if err == nil {
return srv , nil
}
2024-01-18 18:52:01 +00:00
slog . Warn ( fmt . Sprintf ( "Failed to load dynamic library %s %s" , dynLib , err ) )
2024-01-10 04:29:58 +00:00
err2 = err
2023-12-20 18:36:01 +00:00
}
2024-01-10 04:29:58 +00:00
return nil , err2
2023-12-20 18:36:01 +00:00
}