2023-07-21 13:33:56 -07:00
package llm
import (
2023-08-30 16:35:03 -04:00
"context"
2024-01-11 14:43:16 -08:00
"fmt"
2024-01-18 10:52:01 -08:00
"log/slog"
2023-07-21 13:33:56 -07:00
"os"
2023-10-05 12:53:47 -04:00
"runtime"
2023-07-21 13:33:56 -07:00
"github.com/jmorganca/ollama/api"
2023-11-29 11:00:37 -08:00
"github.com/jmorganca/ollama/gpu"
2023-07-21 13:33:56 -07:00
)
type LLM interface {
2023-12-05 14:57:33 -05:00
Predict ( context . Context , PredictOpts , func ( PredictResult ) ) error
2023-08-30 16:35:03 -04:00
Embedding ( context . Context , string ) ( [ ] float64 , error )
Encode ( context . Context , string ) ( [ ] int , error )
Decode ( context . Context , [ ] int ) ( string , error )
2023-07-21 13:33:56 -07:00
Close ( )
}
2023-11-30 10:30:23 -08:00
func New ( workDir , model string , adapters , projectors [ ] string , opts api . Options ) ( LLM , error ) {
2023-07-21 13:33:56 -07:00
if _ , err := os . Stat ( model ) ; err != nil {
return nil , err
}
f , err := os . Open ( model )
if err != nil {
return nil , err
}
2023-08-14 16:08:02 -07:00
defer f . Close ( )
2023-07-21 13:33:56 -07:00
2023-09-07 13:55:37 -04:00
ggml , err := DecodeGGML ( f )
2023-07-21 13:33:56 -07:00
if err != nil {
return nil , err
}
2024-01-12 14:54:01 -08:00
if opts . NumCtx > int ( ggml . NumCtx ( ) ) {
2024-01-18 10:52:01 -08:00
slog . Warn ( fmt . Sprintf ( "requested context length is greater than model's max context length (%d > %d), using %d instead" , opts . NumCtx , ggml . NumCtx ( ) , ggml . NumCtx ( ) ) )
2024-01-12 14:54:01 -08:00
opts . NumCtx = int ( ggml . NumCtx ( ) )
}
2024-01-08 16:42:00 -05:00
if opts . NumCtx < 4 {
opts . NumCtx = 4
}
2024-01-11 01:45:31 -05:00
vram , _ := gpu . CheckVRAM ( )
size := ggml . Size
2024-01-08 16:42:00 -05:00
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
2024-01-11 01:45:31 -05:00
kv := 2 * 2 * int64 ( opts . NumCtx ) * int64 ( ggml . NumLayers ( ) ) * int64 ( ggml . NumEmbed ( ) ) * int64 ( ggml . NumHeadKv ( ) ) / int64 ( ggml . NumHead ( ) )
2023-10-12 10:36:23 -07:00
2024-01-08 16:42:00 -05:00
// this amount is the overhead + tensors in memory
2024-01-09 09:45:42 -08:00
// TODO: get this from the llama.cpp's graph calculations instead of
2024-01-08 21:32:44 -05:00
// estimating it's 1/6 * kv_cache_size * num_gqa
2024-01-11 01:45:31 -05:00
graph := int64 ( ggml . NumGQA ( ) ) * kv / 6
2024-01-08 16:42:00 -05:00
info := gpu . GetGPUInfo ( )
2024-01-11 01:45:31 -05:00
switch runtime . GOOS {
case "darwin" :
if opts . NumGPU == 0 {
break
}
2024-01-08 16:42:00 -05:00
2024-01-11 01:45:31 -05:00
if size + kv + graph > vram {
2024-01-18 10:52:01 -08:00
slog . Info ( "not enough vram available, falling back to CPU only" )
2024-01-11 14:43:16 -08:00
info . Library = "cpu"
info . Variant = gpu . GetCPUVariant ( )
2024-01-11 01:45:31 -05:00
opts . NumGPU = 0
break
}
2024-01-22 17:40:06 -08:00
// TODO: implement layer splitting on macOS
opts . NumGPU = 999
2024-01-11 01:45:31 -05:00
default :
2024-01-11 14:43:16 -08:00
if info . Library == "cpu" {
2024-01-18 10:52:01 -08:00
slog . Info ( "GPU not available, falling back to CPU" )
2024-01-11 01:45:31 -05:00
opts . NumGPU = 0
break
}
// don't use GPU at all if no layers are loaded
if opts . NumGPU == 0 {
2024-01-11 14:43:16 -08:00
info . Library = "cpu"
info . Variant = gpu . GetCPUVariant ( )
2024-01-11 01:45:31 -05:00
break
}
// user-defined GPU count
if opts . NumGPU != - 1 {
break
}
// the "main" GPU needs the most memory and determines the limit
// of how many layers can be loaded. It needs to fit:
// 1. the full compute graph allocation for all devices (graph)
// 2. the proportional kv cache for all devices (kv * % layers)
// 3. the proportional model (size * % layers / # devices)
// This estimates the number of layers
maxlayers := int64 ( ggml . NumLayers ( ) ) + 1
devices := int64 ( info . DeviceCount )
avg := vram / devices
layers := maxlayers * ( avg - graph ) / ( kv + size / devices )
if layers > maxlayers {
layers = maxlayers
}
2024-01-08 16:42:00 -05:00
2024-01-11 01:45:31 -05:00
// 1 + 2 must fit on the main gpu
min := graph + kv * layers / maxlayers
if layers <= 0 || min > avg {
2024-01-18 10:52:01 -08:00
slog . Info ( "not enough vram available, falling back to CPU only" )
2024-01-11 14:43:16 -08:00
info . Library = "cpu"
info . Variant = gpu . GetCPUVariant ( )
2024-01-11 01:45:31 -05:00
opts . NumGPU = 0
break
2023-10-13 14:41:51 -07:00
}
2024-01-11 01:45:31 -05:00
opts . NumGPU = int ( layers )
2023-08-03 15:47:36 -07:00
}
2023-11-24 13:58:09 -05:00
opts . RopeFrequencyBase = 0.0
opts . RopeFrequencyScale = 0.0
2024-02-07 17:27:49 -08:00
return newLlmServer ( info , workDir , model , adapters , projectors , opts )
2023-11-29 11:00:37 -08:00
}
// Give any native cgo implementations an opportunity to initialize
func Init ( workdir string ) error {
return nativeInit ( workdir )
2023-07-21 13:33:56 -07:00
}
2023-12-20 10:36:01 -08:00
2024-02-07 17:27:49 -08:00
func newLlmServer ( gpuInfo gpu . GpuInfo , workDir , model string , adapters , projectors [ ] string , opts api . Options ) ( LLM , error ) {
2024-01-09 20:29:58 -08:00
dynLibs := getDynLibs ( gpuInfo )
2024-01-07 15:48:05 -08:00
// Check to see if the user has requested a specific library instead of auto-detecting
demandLib := os . Getenv ( "OLLAMA_LLM_LIBRARY" )
if demandLib != "" {
2024-01-09 20:29:58 -08:00
libPath := availableDynLibs [ demandLib ]
2024-01-07 15:48:05 -08:00
if libPath == "" {
2024-01-18 10:52:01 -08:00
slog . Info ( fmt . Sprintf ( "Invalid OLLAMA_LLM_LIBRARY %s - not found" , demandLib ) )
2024-01-07 15:48:05 -08:00
} else {
2024-01-18 10:52:01 -08:00
slog . Info ( fmt . Sprintf ( "Loading OLLAMA_LLM_LIBRARY=%s" , demandLib ) )
2024-01-09 20:29:58 -08:00
dynLibs = [ ] string { libPath }
2024-01-07 15:48:05 -08:00
}
}
2024-02-07 17:27:49 -08:00
// We stage into a temp directory, and if we've been idle for a while, it may have been reaped
_ , err := os . Stat ( dynLibs [ 0 ] )
if err != nil {
slog . Info ( fmt . Sprintf ( "%s has disappeared, reloading libraries" , dynLibs [ 0 ] ) )
err = nativeInit ( workDir )
if err != nil {
return nil , err
}
}
2024-01-09 20:29:58 -08:00
err2 := fmt . Errorf ( "unable to locate suitable llm library" )
for _ , dynLib := range dynLibs {
srv , err := newDynExtServer ( dynLib , model , adapters , projectors , opts )
2023-12-20 10:36:01 -08:00
if err == nil {
return srv , nil
}
2024-01-18 10:52:01 -08:00
slog . Warn ( fmt . Sprintf ( "Failed to load dynamic library %s %s" , dynLib , err ) )
2024-01-09 20:29:58 -08:00
err2 = err
2023-12-20 10:36:01 -08:00
}
2024-01-09 20:29:58 -08:00
return nil , err2
2023-12-20 10:36:01 -08:00
}