2023-11-13 17:20:34 -08:00
package llm
/ *
2024-01-04 09:40:15 -08:00
# cgo CFLAGS : - I $ { SRCDIR } / ext_server - I $ { SRCDIR } / llama . cpp - I $ { SRCDIR } / llama . cpp / common - I $ { SRCDIR } / llama . cpp / examples / server
2023-11-13 17:20:34 -08:00
# cgo CFLAGS : - DNDEBUG - DLLAMA_SERVER_LIBRARY = 1 - D_XOPEN_SOURCE = 600 - DACCELERATE_NEW_LAPACK - DACCELERATE_LAPACK_ILP64
2024-01-07 10:39:19 -05:00
# cgo CFLAGS : - Wmissing - noreturn - Wextra - Wcast - qual - Wno - unused - function - Wno - array - bounds
2024-01-28 17:51:23 -08:00
# cgo CPPFLAGS : - Ofast - Wextra - Wno - unused - function - Wno - unused - variable - Wno - deprecated - declarations
2023-11-13 17:20:34 -08:00
# cgo darwin CFLAGS : - D_DARWIN_C_SOURCE
# cgo darwin CPPFLAGS : - DGGML_USE_ACCELERATE
2023-12-19 13:32:24 -08:00
# cgo darwin CPPFLAGS : - DGGML_USE_METAL - DGGML_METAL_NDEBUG
2023-11-13 17:20:34 -08:00
# cgo darwin LDFLAGS : - lc ++ - framework Accelerate
2023-12-19 13:32:24 -08:00
# cgo darwin LDFLAGS : - framework Foundation - framework Metal - framework MetalKit - framework MetalPerformanceShaders
2023-11-13 17:20:34 -08:00
# cgo linux CFLAGS : - D_GNU_SOURCE
2023-12-23 11:35:44 -08:00
# cgo linux LDFLAGS : - lrt - ldl - lstdc ++ - lm
# cgo linux windows LDFLAGS : - lpthread
2023-11-13 17:20:34 -08:00
# include < stdlib . h >
2024-01-09 20:29:58 -08:00
# include "dyn_ext_server.h"
2023-11-13 17:20:34 -08:00
* /
import "C"
2024-01-09 20:29:58 -08:00
2023-11-13 17:20:34 -08:00
import (
"bytes"
"context"
"encoding/json"
"fmt"
2024-01-18 10:52:01 -08:00
"log/slog"
2024-01-09 20:29:58 -08:00
"os"
"path/filepath"
2023-11-29 11:00:37 -08:00
"strings"
2023-11-13 17:20:34 -08:00
"sync"
"time"
"unsafe"
"github.com/jmorganca/ollama/api"
2024-02-15 17:15:09 -08:00
"github.com/jmorganca/ollama/gpu"
2023-11-13 17:20:34 -08:00
)
2024-01-09 20:29:58 -08:00
type dynExtServer struct {
s C . struct_dynamic_llama_server
options api . Options
2023-11-13 17:20:34 -08:00
}
// mutex serializes llm server instances: newDynExtServer acquires it (waiting
// for any prior server) and it is released either by Close or by
// newDynExtServer itself when loading or initialization fails.
// Note: current implementation does not support concurrent instantiations
var mutex sync . Mutex
2023-12-23 11:35:44 -08:00
func newExtServerResp ( len C . size_t ) C . ext_server_resp_t {
var resp C . ext_server_resp_t
resp . msg_len = len
bytes := make ( [ ] byte , len )
resp . msg = ( * C . char ) ( C . CBytes ( bytes ) )
return resp
2023-11-29 11:00:37 -08:00
}
2023-12-23 11:35:44 -08:00
func freeExtServerResp ( resp C . ext_server_resp_t ) {
if resp . msg_len == 0 {
return
}
C . free ( unsafe . Pointer ( resp . msg ) )
2023-11-29 11:00:37 -08:00
}
2023-12-23 11:35:44 -08:00
func extServerResponseToErr ( resp C . ext_server_resp_t ) error {
return fmt . Errorf ( C . GoString ( resp . msg ) )
2023-11-29 11:00:37 -08:00
}
2024-01-09 20:29:58 -08:00
// llm is the package-level server instance most recently assigned by
// newDynExtServer, which also returns it to the caller.
// Note: current implementation does not support concurrent instantiations
var llm * dynExtServer
func newDynExtServer ( library , model string , adapters , projectors [ ] string , opts api . Options ) ( LLM , error ) {
2023-11-13 17:20:34 -08:00
if ! mutex . TryLock ( ) {
2024-01-18 10:52:01 -08:00
slog . Info ( "concurrent llm servers not yet supported, waiting for prior server to complete" )
2023-11-13 17:20:34 -08:00
mutex . Lock ( )
}
2024-02-15 17:15:09 -08:00
gpu . UpdatePath ( filepath . Dir ( library ) )
2024-01-09 20:29:58 -08:00
libPath := C . CString ( library )
defer C . free ( unsafe . Pointer ( libPath ) )
2024-01-13 14:46:34 -08:00
resp := newExtServerResp ( 512 )
2024-01-09 20:29:58 -08:00
defer freeExtServerResp ( resp )
var srv C . struct_dynamic_llama_server
C . dyn_init ( libPath , & srv , & resp )
if resp . id < 0 {
mutex . Unlock ( )
return nil , fmt . Errorf ( "Unable to load dynamic library: %s" , C . GoString ( resp . msg ) )
}
llm = & dynExtServer {
s : srv ,
options : opts ,
}
2024-01-18 10:52:01 -08:00
slog . Info ( fmt . Sprintf ( "Loading Dynamic llm server: %s" , library ) )
2024-01-08 16:42:00 -05:00
2023-11-29 11:00:37 -08:00
var sparams C . ext_server_params_t
2023-11-13 17:20:34 -08:00
sparams . model = C . CString ( model )
defer C . free ( unsafe . Pointer ( sparams . model ) )
sparams . embedding = true
sparams . n_ctx = C . uint ( opts . NumCtx )
sparams . n_batch = C . uint ( opts . NumBatch )
2024-01-08 16:42:00 -05:00
sparams . n_gpu_layers = C . int ( opts . NumGPU )
2023-11-13 17:20:34 -08:00
sparams . main_gpu = C . int ( opts . MainGPU )
2023-12-14 10:25:12 -08:00
sparams . n_parallel = 1 // TODO - wire up concurrency
2023-11-13 17:20:34 -08:00
// Always use the value encoded in the model
sparams . rope_freq_base = 0.0
sparams . rope_freq_scale = 0.0
2023-11-29 11:00:37 -08:00
sparams . memory_f16 = C . bool ( opts . F16KV )
sparams . use_mlock = C . bool ( opts . UseMLock )
sparams . use_mmap = C . bool ( opts . UseMMap )
2024-02-20 17:42:31 -05:00
if opts . UseNUMA {
sparams . numa = C . int ( 1 )
} else {
sparams . numa = C . int ( 0 )
}
2023-11-13 17:20:34 -08:00
sparams . lora_adapters = nil
for i := 0 ; i < len ( adapters ) ; i ++ {
2023-11-29 11:00:37 -08:00
la := ( * C . ext_server_lora_adapter_t ) ( C . malloc ( C . sizeof_ext_server_lora_adapter_t ) )
2023-11-13 17:20:34 -08:00
defer C . free ( unsafe . Pointer ( la ) )
la . adapter = C . CString ( adapters [ i ] )
defer C . free ( unsafe . Pointer ( la . adapter ) )
la . scale = C . float ( 1.0 ) // TODO expose scale/weights up through ollama UX
la . next = nil
if i == 0 {
sparams . lora_adapters = la
} else {
tmp := sparams . lora_adapters
for ; tmp . next != nil ; tmp = tmp . next {
}
tmp . next = la
}
}
2023-11-29 11:00:37 -08:00
if len ( projectors ) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
sparams . mmproj = C . CString ( projectors [ 0 ] )
defer C . free ( unsafe . Pointer ( sparams . mmproj ) )
} else {
sparams . mmproj = nil
}
2023-11-13 17:20:34 -08:00
2023-12-21 16:23:36 -08:00
sparams . n_threads = C . uint ( opts . NumThread )
2023-11-13 17:20:34 -08:00
2024-01-22 12:26:49 -08:00
if debug := os . Getenv ( "OLLAMA_DEBUG" ) ; debug != "" {
sparams . verbose_logging = C . bool ( true )
} else {
sparams . verbose_logging = C . bool ( false )
}
2024-01-18 10:52:01 -08:00
slog . Info ( "Initializing llama server" )
2024-02-15 17:15:09 -08:00
slog . Debug ( fmt . Sprintf ( "server params: %+v" , sparams ) )
2024-03-11 16:48:27 -04:00
initResp := newExtServerResp ( 512 )
2024-01-09 20:29:58 -08:00
defer freeExtServerResp ( initResp )
C . dyn_llama_server_init ( llm . s , & sparams , & initResp )
if initResp . id < 0 {
2024-01-20 20:54:46 -05:00
mutex . Unlock ( )
2024-01-22 12:08:22 -08:00
err := extServerResponseToErr ( initResp )
slog . Debug ( fmt . Sprintf ( "failure during initialization: %s" , err ) )
return nil , err
2023-11-13 17:20:34 -08:00
}
2024-01-18 10:52:01 -08:00
slog . Info ( "Starting llama main loop" )
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_start ( llm . s )
return llm , nil
2023-11-13 17:20:34 -08:00
}
2024-01-09 20:29:58 -08:00
func ( llm * dynExtServer ) Predict ( ctx context . Context , predict PredictOpts , fn func ( PredictResult ) ) error {
2023-11-29 11:00:37 -08:00
resp := newExtServerResp ( 128 )
defer freeExtServerResp ( resp )
2024-01-31 18:56:12 -08:00
2023-11-29 11:00:37 -08:00
if len ( predict . Images ) > 0 {
2024-01-31 18:56:12 -08:00
slog . Info ( fmt . Sprintf ( "loaded %d images" , len ( predict . Images ) ) )
2023-11-29 11:00:37 -08:00
}
2023-11-13 17:20:34 -08:00
request := map [ string ] any {
"prompt" : predict . Prompt ,
"stream" : true ,
2024-01-03 12:01:42 -05:00
"n_predict" : predict . Options . NumPredict ,
"n_keep" : predict . Options . NumKeep ,
"temperature" : predict . Options . Temperature ,
"top_k" : predict . Options . TopK ,
"top_p" : predict . Options . TopP ,
"tfs_z" : predict . Options . TFSZ ,
"typical_p" : predict . Options . TypicalP ,
"repeat_last_n" : predict . Options . RepeatLastN ,
"repeat_penalty" : predict . Options . RepeatPenalty ,
"presence_penalty" : predict . Options . PresencePenalty ,
"frequency_penalty" : predict . Options . FrequencyPenalty ,
"mirostat" : predict . Options . Mirostat ,
"mirostat_tau" : predict . Options . MirostatTau ,
"mirostat_eta" : predict . Options . MirostatEta ,
"penalize_nl" : predict . Options . PenalizeNewline ,
"seed" : predict . Options . Seed ,
"stop" : predict . Options . Stop ,
2024-01-31 18:56:12 -08:00
"image_data" : predict . Images ,
2024-01-25 13:46:20 -08:00
"cache_prompt" : true ,
2023-11-13 17:20:34 -08:00
}
if predict . Format == "json" {
request [ "grammar" ] = jsonGrammar
2024-03-12 19:07:11 -04:00
if ! strings . Contains ( strings . ToLower ( predict . Prompt ) , "json" ) {
slog . Warn ( "Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt." )
}
2023-11-13 17:20:34 -08:00
}
2023-11-29 11:00:37 -08:00
retryDelay := 100 * time . Microsecond
for retries := 0 ; retries < maxRetries ; retries ++ {
if retries > 0 {
time . Sleep ( retryDelay ) // wait before retrying
retryDelay *= 2 // exponential backoff
}
2023-11-13 17:20:34 -08:00
2023-11-29 11:00:37 -08:00
// Handling JSON marshaling with special characters unescaped.
buffer := & bytes . Buffer { }
enc := json . NewEncoder ( buffer )
enc . SetEscapeHTML ( false )
2023-11-13 17:20:34 -08:00
2023-11-29 11:00:37 -08:00
if err := enc . Encode ( request ) ; err != nil {
return fmt . Errorf ( "failed to marshal data: %w" , err )
}
2023-11-13 17:20:34 -08:00
2023-11-29 11:00:37 -08:00
req := C . CString ( buffer . String ( ) )
defer C . free ( unsafe . Pointer ( req ) )
2023-11-13 17:20:34 -08:00
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_completion ( llm . s , req , & resp )
2023-11-29 11:00:37 -08:00
if resp . id < 0 {
return extServerResponseToErr ( resp )
}
2023-11-13 17:20:34 -08:00
2023-11-29 11:00:37 -08:00
retryNeeded := false
2024-03-12 22:08:25 -04:00
// keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string
var tokenRepeat int
2023-11-29 11:00:37 -08:00
out :
for {
select {
case <- ctx . Done ( ) :
2024-03-12 22:08:25 -04:00
return cancelCompletion ( llm , resp )
2023-11-29 11:00:37 -08:00
default :
var result C . ext_server_task_result_t
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_completion_next_result ( llm . s , resp . id , & result )
2023-11-29 11:00:37 -08:00
json_resp := C . GoString ( result . json_resp )
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_release_task_result ( llm . s , & result )
2023-11-29 11:00:37 -08:00
var p prediction
if err := json . Unmarshal ( [ ] byte ( json_resp ) , & p ) ; err != nil {
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_completion_cancel ( llm . s , resp . id , & resp )
2023-11-29 11:00:37 -08:00
if resp . id < 0 {
return fmt . Errorf ( "error unmarshaling llm prediction response: %w and cancel %s" , err , C . GoString ( resp . msg ) )
} else {
return fmt . Errorf ( "error unmarshaling llm prediction response: %w" , err )
}
}
if bool ( result . error ) && strings . Contains ( json_resp , "slot unavailable" ) {
retryNeeded = true
// task will already be canceled
break out
}
2024-03-12 22:08:25 -04:00
switch {
case strings . TrimSpace ( p . Content ) == lastToken :
tokenRepeat ++
default :
lastToken = strings . TrimSpace ( p . Content )
tokenRepeat = 0
}
// 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 {
slog . Debug ( "prediction aborted, token repeat limit reached" )
return cancelCompletion ( llm , resp )
}
2023-11-29 11:00:37 -08:00
if p . Content != "" {
fn ( PredictResult {
Content : p . Content ,
} )
}
2024-02-08 22:22:50 -08:00
if p . Stop || bool ( result . stop ) {
2023-11-29 11:00:37 -08:00
fn ( PredictResult {
Done : true ,
PromptEvalCount : p . Timings . PromptN ,
PromptEvalDuration : parseDurationMs ( p . Timings . PromptMS ) ,
EvalCount : p . Timings . PredictedN ,
EvalDuration : parseDurationMs ( p . Timings . PredictedMS ) ,
} )
return nil
}
2023-11-13 17:20:34 -08:00
}
}
2023-11-29 11:00:37 -08:00
if ! retryNeeded {
return nil // success
}
2023-11-13 17:20:34 -08:00
}
2023-11-29 11:00:37 -08:00
// should never reach here ideally
return fmt . Errorf ( "max retries exceeded" )
}
2024-03-12 22:08:25 -04:00
// cancelCompletion asks the server to cancel the completion task identified by
// resp.id. It returns nil when the server leaves a non-negative id in the
// response, and otherwise the server's message as an error.
func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error {
	C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
	if resp.id >= 0 {
		return nil
	}
	return extServerResponseToErr(resp)
}
2024-01-09 20:29:58 -08:00
func ( llm * dynExtServer ) Encode ( ctx context . Context , prompt string ) ( [ ] int , error ) {
2023-11-13 17:20:34 -08:00
data , err := json . Marshal ( TokenizeRequest { Content : prompt } )
if err != nil {
return nil , fmt . Errorf ( "marshaling encode data: %w" , err )
}
req := C . CString ( string ( data ) )
defer C . free ( unsafe . Pointer ( req ) )
2023-11-29 11:00:37 -08:00
var json_resp * C . char
resp := newExtServerResp ( 128 )
defer freeExtServerResp ( resp )
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_tokenize ( llm . s , req , & json_resp , & resp )
2023-11-29 11:00:37 -08:00
if resp . id < 0 {
return nil , extServerResponseToErr ( resp )
2023-11-13 17:20:34 -08:00
}
2024-01-09 20:29:58 -08:00
defer C . dyn_llama_server_release_json_resp ( llm . s , & json_resp )
2023-11-13 17:20:34 -08:00
var encoded TokenizeResponse
2023-11-29 11:00:37 -08:00
if err2 := json . Unmarshal ( [ ] byte ( C . GoString ( json_resp ) ) , & encoded ) ; err2 != nil {
2023-11-13 17:20:34 -08:00
return nil , fmt . Errorf ( "unmarshal encode response: %w" , err2 )
}
return encoded . Tokens , err
}
2024-01-09 20:29:58 -08:00
func ( llm * dynExtServer ) Decode ( ctx context . Context , tokens [ ] int ) ( string , error ) {
2023-11-13 17:20:34 -08:00
if len ( tokens ) == 0 {
return "" , nil
}
data , err := json . Marshal ( DetokenizeRequest { Tokens : tokens } )
if err != nil {
return "" , fmt . Errorf ( "marshaling decode data: %w" , err )
}
req := C . CString ( string ( data ) )
defer C . free ( unsafe . Pointer ( req ) )
2023-11-29 11:00:37 -08:00
var json_resp * C . char
resp := newExtServerResp ( 128 )
defer freeExtServerResp ( resp )
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_detokenize ( llm . s , req , & json_resp , & resp )
2023-11-29 11:00:37 -08:00
if resp . id < 0 {
return "" , extServerResponseToErr ( resp )
2023-11-13 17:20:34 -08:00
}
2024-01-09 20:29:58 -08:00
defer C . dyn_llama_server_release_json_resp ( llm . s , & json_resp )
2023-11-13 17:20:34 -08:00
var decoded DetokenizeResponse
2023-11-29 11:00:37 -08:00
if err2 := json . Unmarshal ( [ ] byte ( C . GoString ( json_resp ) ) , & decoded ) ; err2 != nil {
2023-11-13 17:20:34 -08:00
return "" , fmt . Errorf ( "unmarshal encode response: %w" , err2 )
}
return decoded . Content , err
}
2024-01-09 20:29:58 -08:00
func ( llm * dynExtServer ) Embedding ( ctx context . Context , input string ) ( [ ] float64 , error ) {
2023-11-13 17:20:34 -08:00
data , err := json . Marshal ( TokenizeRequest { Content : input } )
if err != nil {
return nil , fmt . Errorf ( "error marshaling embed data: %w" , err )
}
req := C . CString ( string ( data ) )
defer C . free ( unsafe . Pointer ( req ) )
2023-11-29 11:00:37 -08:00
var json_resp * C . char
resp := newExtServerResp ( 128 )
defer freeExtServerResp ( resp )
2024-01-09 20:29:58 -08:00
C . dyn_llama_server_embedding ( llm . s , req , & json_resp , & resp )
2023-11-29 11:00:37 -08:00
if resp . id < 0 {
return nil , extServerResponseToErr ( resp )
2023-11-13 17:20:34 -08:00
}
2024-01-09 20:29:58 -08:00
defer C . dyn_llama_server_release_json_resp ( llm . s , & json_resp )
2023-11-13 17:20:34 -08:00
var embedding EmbeddingResponse
2023-11-29 11:00:37 -08:00
if err := json . Unmarshal ( [ ] byte ( C . GoString ( json_resp ) ) , & embedding ) ; err != nil {
2023-11-13 17:20:34 -08:00
return nil , fmt . Errorf ( "unmarshal tokenize response: %w" , err )
}
return embedding . Embedding , nil
}
2024-01-09 20:29:58 -08:00
func ( llm * dynExtServer ) Close ( ) {
C . dyn_llama_server_stop ( llm . s )
2023-11-13 17:20:34 -08:00
mutex . Unlock ( )
}