2024-03-14 10:24:13 -07:00
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
"math/rand"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
2024-03-30 09:50:05 -07:00
"golang.org/x/sync/semaphore"
2024-03-14 10:24:13 -07:00
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
2024-05-04 11:46:01 -07:00
"github.com/ollama/ollama/server/envconfig"
2024-03-14 10:24:13 -07:00
)
2024-03-30 09:50:05 -07:00
// LlamaServer is the set of operations exposed by a running llama server.
// The notes on each method describe how llmServer, the implementation in this
// file, behaves.
type LlamaServer interface {
// Ping performs a single health check and returns an error if the server is unhealthy.
Ping ( ctx context . Context ) error
// WaitUntilRunning polls the server until it reports ready, the context is
// done, the process exits, or an internal timeout elapses.
WaitUntilRunning ( ctx context . Context ) error
// Completion runs a streaming completion, invoking fn for each piece of
// content and once more with Done set when the server signals a stop.
Completion ( ctx context . Context , req CompletionRequest , fn func ( CompletionResponse ) ) error
// Embedding returns the embedding vector the server computes for prompt.
Embedding ( ctx context . Context , prompt string ) ( [ ] float64 , error )
// Tokenize converts content into the server's token ids.
Tokenize ( ctx context . Context , content string ) ( [ ] int , error )
// Detokenize converts token ids back into text.
Detokenize ( ctx context . Context , tokens [ ] int ) ( string , error )
// Close stops the server process.
Close ( ) error
// EstimatedVRAM returns the VRAM estimate recorded when the server was created.
EstimatedVRAM ( ) uint64
}
// llmServer is an instance of the llama.cpp server
type llmServer struct {
2024-03-14 10:24:13 -07:00
port int
cmd * exec . Cmd
done chan error // Channel to signal when the process exits
status * StatusWriter
2024-04-02 16:44:10 -07:00
options api . Options
2024-03-30 09:50:05 -07:00
// TODO - this should be broken down by GPU
2024-05-04 09:15:31 -07:00
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
estimatedTotal uint64 // Total size of model
totalLayers uint64
gpuCount int
2024-05-09 11:10:28 -07:00
loadDuration time . Duration // Record how long it took the model to load
2024-03-30 09:50:05 -07:00
sem * semaphore . Weighted
2024-03-14 10:24:13 -07:00
}
2024-03-30 09:50:05 -07:00
func LoadModel ( model string ) ( * GGML , error ) {
if _ , err := os . Stat ( model ) ; err != nil {
return nil , err
}
2024-03-14 10:24:13 -07:00
f , err := os . Open ( model )
if err != nil {
return nil , err
}
defer f . Close ( )
ggml , _ , err := DecodeGGML ( f )
2024-03-30 09:50:05 -07:00
return ggml , err
}
2024-03-14 10:24:13 -07:00
2024-03-30 09:50:05 -07:00
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer ( gpus gpu . GpuInfoList , model string , ggml * GGML , adapters , projectors [ ] string , opts api . Options ) ( LlamaServer , error ) {
var err error
2024-03-14 10:24:13 -07:00
if opts . NumCtx > int ( ggml . KV ( ) . ContextLength ( ) ) {
2024-04-29 10:07:30 -04:00
slog . Warn ( "requested context length is greater than the model's training context window size" , "requested" , opts . NumCtx , "training size" , ggml . KV ( ) . ContextLength ( ) )
2024-03-14 10:24:13 -07:00
}
if opts . NumCtx < 4 {
opts . NumCtx = 4
}
2024-03-30 09:50:05 -07:00
cpuRunner := ""
var estimatedVRAM uint64
2024-05-04 09:15:31 -07:00
var estimatedTotal uint64
2024-03-30 09:50:05 -07:00
var systemMemory uint64
2024-05-04 09:15:31 -07:00
gpuCount := len ( gpus )
2024-03-30 09:50:05 -07:00
if ( len ( gpus ) == 1 && gpus [ 0 ] . Library == "cpu" ) || opts . NumGPU == 0 {
2024-03-14 10:24:13 -07:00
2024-03-30 09:50:05 -07:00
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
2024-03-14 10:24:13 -07:00
2024-03-30 09:50:05 -07:00
cpuRunner = serverForCpu ( )
2024-05-04 09:15:31 -07:00
gpuCount = 0
2024-03-30 09:50:05 -07:00
} else {
if gpus [ 0 ] . Library == "metal" {
memInfo , err := gpu . GetCPUMem ( )
if err != nil {
slog . Error ( "failed to lookup system memory" , "error" , err )
} else {
systemMemory = memInfo . TotalMemory
slog . Debug ( "system memory" , "total" , format . HumanBytes2 ( systemMemory ) )
}
2024-03-14 10:24:13 -07:00
}
2024-03-30 09:50:05 -07:00
var layers int
2024-05-04 09:15:31 -07:00
layers , estimatedVRAM , estimatedTotal = EstimateGPULayers ( gpus , ggml , projectors , opts )
2024-03-30 09:50:05 -07:00
if gpus [ 0 ] . Library == "metal" && estimatedVRAM > systemMemory {
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
opts . NumGPU = 0
} else if opts . NumGPU < 0 && layers > 0 && gpus [ 0 ] . Library != "cpu" {
opts . NumGPU = layers
2024-04-17 10:29:12 -07:00
}
}
2024-03-30 09:50:05 -07:00
// Loop through potential servers
finalErr := fmt . Errorf ( "no suitable llama servers found" )
2024-03-14 10:24:13 -07:00
if len ( adapters ) > 1 {
return nil , errors . New ( "ollama supports only one lora adapter, but multiple were provided" )
}
availableServers := availableServers ( )
2024-03-30 09:50:05 -07:00
var servers [ ] string
if cpuRunner != "" {
servers = [ ] string { cpuRunner }
} else {
servers = serversForGpu ( gpus [ 0 ] ) // All GPUs in the list are matching Library and Variant
}
2024-05-04 11:46:01 -07:00
demandLib := envconfig . LLMLibrary
2024-03-14 10:24:13 -07:00
if demandLib != "" {
serverPath := availableServers [ demandLib ]
if serverPath == "" {
slog . Info ( fmt . Sprintf ( "Invalid OLLAMA_LLM_LIBRARY %s - not found" , demandLib ) )
} else {
slog . Info ( "user override" , "OLLAMA_LLM_LIBRARY" , demandLib , "path" , serverPath )
servers = [ ] string { demandLib }
2024-05-04 09:15:31 -07:00
if strings . HasPrefix ( demandLib , "cpu" ) {
// Omit the GPU flag to silence the warning
opts . NumGPU = - 1
}
2024-03-14 10:24:13 -07:00
}
}
if len ( servers ) == 0 {
2024-03-30 09:50:05 -07:00
return nil , fmt . Errorf ( "no servers found for %v" , gpus )
2024-03-14 10:24:13 -07:00
}
params := [ ] string {
"--model" , model ,
"--ctx-size" , fmt . Sprintf ( "%d" , opts . NumCtx ) ,
"--batch-size" , fmt . Sprintf ( "%d" , opts . NumBatch ) ,
"--embedding" ,
}
2024-05-04 11:46:01 -07:00
if envconfig . Debug {
2024-03-14 10:24:13 -07:00
params = append ( params , "--log-format" , "json" )
} else {
params = append ( params , "--log-disable" )
}
2024-04-02 16:06:45 -07:00
if opts . NumGPU >= 0 {
2024-03-14 10:24:13 -07:00
params = append ( params , "--n-gpu-layers" , fmt . Sprintf ( "%d" , opts . NumGPU ) )
}
2024-05-04 11:46:01 -07:00
if envconfig . Debug {
2024-03-14 10:24:13 -07:00
params = append ( params , "--verbose" )
}
if opts . MainGPU > 0 {
params = append ( params , "--main-gpu" , fmt . Sprintf ( "%d" , opts . MainGPU ) )
}
if len ( adapters ) > 0 {
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
params = append ( params , "--lora" , adapters [ 0 ] )
}
if len ( projectors ) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append ( params , "--mmproj" , projectors [ 0 ] )
}
if opts . NumThread > 0 {
params = append ( params , "--threads" , fmt . Sprintf ( "%d" , opts . NumThread ) )
}
if ! opts . F16KV {
params = append ( params , "--memory-f32" )
}
if opts . UseMLock {
params = append ( params , "--mlock" )
}
if ! opts . UseMMap {
params = append ( params , "--no-mmap" )
}
if opts . UseNUMA {
params = append ( params , "--numa" )
}
2024-05-04 11:46:01 -07:00
numParallel := envconfig . NumParallel
2024-05-05 20:50:31 -07:00
// TODO (jmorganca): multimodal models don't support parallel yet
// see https://github.com/ollama/ollama/issues/4165
if len ( projectors ) > 0 {
numParallel = 1
slog . Warn ( "multimodal models don't support parallel requests yet" )
}
2024-03-30 09:50:05 -07:00
params = append ( params , "--parallel" , fmt . Sprintf ( "%d" , numParallel ) )
2024-03-14 10:24:13 -07:00
for i := 0 ; i < len ( servers ) ; i ++ {
dir := availableServers [ servers [ i ] ]
2024-04-22 16:22:05 -07:00
if dir == "" {
// Shouldn't happen
finalErr = fmt . Errorf ( "[%d] server %s not listed in available servers %v" , i , servers [ i ] , availableServers )
slog . Error ( "sever list inconsistent" , "error" , finalErr )
continue
}
2024-03-14 10:24:13 -07:00
2024-05-04 09:15:31 -07:00
if strings . HasPrefix ( servers [ i ] , "cpu" ) {
// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
gpuCount = 0
}
2024-03-14 10:24:13 -07:00
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
port := 0
if a , err := net . ResolveTCPAddr ( "tcp" , "localhost:0" ) ; err == nil {
var l * net . TCPListener
if l , err = net . ListenTCP ( "tcp" , a ) ; err == nil {
port = l . Addr ( ) . ( * net . TCPAddr ) . Port
l . Close ( )
}
}
if port == 0 {
slog . Debug ( "ResolveTCPAddr failed " , "error" , err )
port = rand . Intn ( 65535 - 49152 ) + 49152 // get a random port in the ephemeral range
}
finalParams := append ( params , "--port" , strconv . Itoa ( port ) )
pathEnv := "LD_LIBRARY_PATH"
if runtime . GOOS == "windows" {
pathEnv = "PATH"
}
2024-05-05 17:45:43 -07:00
// prepend the server directory to LD_LIBRARY_PATH/PATH
2024-03-14 10:24:13 -07:00
libraryPaths := [ ] string { dir }
2024-03-30 09:50:05 -07:00
2024-03-14 10:24:13 -07:00
if libraryPath , ok := os . LookupEnv ( pathEnv ) ; ok {
// Append our runner directory to the path
// This will favor system libraries over our bundled library dependencies
2024-05-05 17:45:43 -07:00
libraryPaths = append ( libraryPaths , filepath . SplitList ( libraryPath ) ... )
2024-03-14 10:24:13 -07:00
}
2024-03-30 09:50:05 -07:00
// Note: we always put the dependency path first
// since this was the exact version we verified for AMD GPUs
// and we favor what the user had in their path
if gpus [ 0 ] . DependencyPath != "" {
// TODO refine for multi-gpu support
libraryPaths = append ( [ ] string { gpus [ 0 ] . DependencyPath } , libraryPaths ... )
}
2024-03-14 10:24:13 -07:00
server := filepath . Join ( dir , "ollama_llama_server" )
if runtime . GOOS == "windows" {
server = server + ".exe"
}
2024-04-23 10:05:26 -07:00
// Detect tmp cleaners wiping out the file
_ , err := os . Stat ( server )
if errors . Is ( err , os . ErrNotExist ) {
slog . Warn ( "llama server disappeared, reinitializing payloads" , "path" , server , "error" , err )
err = Init ( )
if err != nil {
slog . Warn ( "failed to reinitialize payloads" , "error" , err )
return nil , err
}
}
2024-03-30 09:50:05 -07:00
s := & llmServer {
2024-05-04 09:15:31 -07:00
port : port ,
cmd : exec . Command ( server , finalParams ... ) ,
status : NewStatusWriter ( os . Stderr ) ,
options : opts ,
estimatedVRAM : estimatedVRAM ,
estimatedTotal : estimatedTotal ,
sem : semaphore . NewWeighted ( int64 ( numParallel ) ) ,
totalLayers : ggml . KV ( ) . BlockCount ( ) + 1 ,
gpuCount : gpuCount ,
2024-05-09 11:10:28 -07:00
done : make ( chan error , 1 ) ,
2024-03-14 10:24:13 -07:00
}
2024-03-30 09:50:05 -07:00
2024-05-05 17:45:43 -07:00
s . cmd . Env = os . Environ ( )
2024-03-14 10:24:13 -07:00
s . cmd . Stdout = os . Stdout
s . cmd . Stderr = s . status
2024-05-05 17:45:43 -07:00
visibleDevicesEnv , visibleDevicesEnvVal := gpu . GpuInfoList ( gpus ) . GetVisibleDevicesEnv ( )
pathEnvVal := strings . Join ( libraryPaths , string ( filepath . ListSeparator ) )
// Update or add the path and visible devices variable with our adjusted version
pathNeeded := true
devicesNeeded := visibleDevicesEnv != ""
for i := range s . cmd . Env {
cmp := strings . SplitN ( s . cmd . Env [ i ] , "=" , 2 )
if strings . EqualFold ( cmp [ 0 ] , pathEnv ) {
s . cmd . Env [ i ] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if devicesNeeded && strings . EqualFold ( cmp [ 0 ] , visibleDevicesEnv ) {
s . cmd . Env [ i ] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
devicesNeeded = false
}
}
if pathNeeded {
s . cmd . Env = append ( s . cmd . Env , pathEnv + "=" + pathEnvVal )
}
if devicesNeeded {
s . cmd . Env = append ( s . cmd . Env , visibleDevicesEnv + "=" + visibleDevicesEnvVal )
2024-03-30 09:50:05 -07:00
}
2024-03-14 10:24:13 -07:00
slog . Info ( "starting llama server" , "cmd" , s . cmd . String ( ) )
2024-03-30 09:50:05 -07:00
// Log at debug as the environment is inherited and might contain sensitive information
slog . Debug ( "subprocess" , "environment" , s . cmd . Env )
2024-03-14 10:24:13 -07:00
if err = s . cmd . Start ( ) ; err != nil {
2024-05-07 16:46:15 -07:00
// Detect permission denied and augment them essage about noexec
if errors . Is ( err , os . ErrPermission ) {
finalErr = fmt . Errorf ( "unable to start server %w. %s may have noexec set. Set OLLAMA_TMPDIR for server to a writable executable directory" , err , dir )
continue
}
2024-03-14 10:24:13 -07:00
msg := ""
if s . status != nil && s . status . LastErrMsg != "" {
msg = s . status . LastErrMsg
}
err = fmt . Errorf ( "error starting the external llama server: %v %s" , err , msg )
finalErr = err
continue
}
2024-05-09 11:10:28 -07:00
// reap subprocess when it exits
go func ( ) {
s . done <- s . cmd . Wait ( )
} ( )
2024-03-14 10:24:13 -07:00
return s , nil
}
slog . Error ( "unable to load any llama server" , "error" , finalErr )
return nil , finalErr
}
2024-04-05 14:50:38 -07:00
func projectorMemoryRequirements ( filename string ) uint64 {
2024-03-14 10:24:13 -07:00
file , err := os . Open ( filename )
if err != nil {
return 0
}
defer file . Close ( )
ggml , _ , err := DecodeGGML ( file )
if err != nil {
return 0
}
2024-04-03 15:00:31 -07:00
var mem uint64
for _ , layer := range ggml . Tensors ( ) . Layers ( ) {
mem += layer . size ( )
2024-03-14 10:24:13 -07:00
}
2024-04-05 14:50:38 -07:00
return mem
2024-03-14 10:24:13 -07:00
}
// ServerStatus is the health state of the llama server as determined by a
// health check.
type ServerStatus int

const ( // iota is reset to 0
	// ServerStatusReady means the health endpoint reported "ok".
	ServerStatusReady ServerStatus = iota
	// ServerStatusNoSlotsAvailable means the health endpoint reported
	// "no slot available".
	ServerStatusNoSlotsAvailable
	// ServerStatusLoadingModel means the health endpoint reported
	// "loading model".
	ServerStatusLoadingModel
	// ServerStatusNotResponding means the health request hit its deadline.
	ServerStatusNotResponding
	// ServerStatusError covers every other failure.
	ServerStatusError
)
2024-03-30 09:50:05 -07:00
func ( s ServerStatus ) ToString ( ) string {
switch s {
case ServerStatusReady :
return "llm server ready"
2024-05-06 14:22:53 -07:00
case ServerStatusNoSlotsAvailable :
2024-03-30 09:50:05 -07:00
return "llm busy - no slots available"
case ServerStatusLoadingModel :
return "llm server loading model"
case ServerStatusNotResponding :
return "llm server not responding"
default :
return "llm server error"
}
}
2024-03-14 10:24:13 -07:00
// ServerStatusResp is the JSON document decoded from the server's /health
// endpoint.
type ServerStatusResp struct {
// Status is the state string; "ok", "no slot available" and "loading model"
// are the values getServerStatus recognizes.
Status string ` json:"status" `
SlotsIdle int ` json:"slots_idle" `
SlotsProcessing int ` json:"slots_processing" `
Error string ` json:"error" `
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) getServerStatus ( ctx context . Context ) ( ServerStatus , error ) {
2024-03-14 10:24:13 -07:00
// Fail fast if its exited
if s . cmd . ProcessState != nil {
msg := ""
if s . status != nil && s . status . LastErrMsg != "" {
msg = s . status . LastErrMsg
}
2024-05-07 16:46:15 -07:00
if s . cmd . ProcessState . ExitCode ( ) == - 1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot
slog . Warn ( "llama runner process no longer running" , "sys" , s . cmd . ProcessState . Sys ( ) , "string" , s . cmd . ProcessState . String ( ) )
}
2024-03-14 10:24:13 -07:00
return ServerStatusError , fmt . Errorf ( "llama runner process no longer running: %d %s" , s . cmd . ProcessState . ExitCode ( ) , msg )
}
req , err := http . NewRequestWithContext ( ctx , http . MethodGet , fmt . Sprintf ( "http://127.0.0.1:%d/health" , s . port ) , nil )
if err != nil {
return ServerStatusError , fmt . Errorf ( "error creating GET request: %v" , err )
}
req . Header . Set ( "Content-Type" , "application/json" )
resp , err := http . DefaultClient . Do ( req )
if err != nil {
if errors . Is ( err , context . DeadlineExceeded ) {
return ServerStatusNotResponding , fmt . Errorf ( "server not responding" )
}
return ServerStatusError , fmt . Errorf ( "health resp: %w" , err )
}
defer resp . Body . Close ( )
body , err := io . ReadAll ( resp . Body )
if err != nil {
return ServerStatusError , fmt . Errorf ( "read health request: %w" , err )
}
var status ServerStatusResp
if err := json . Unmarshal ( body , & status ) ; err != nil {
return ServerStatusError , fmt . Errorf ( "health unmarshal encode response: %w" , err )
}
switch status . Status {
case "ok" :
return ServerStatusReady , nil
case "no slot available" :
2024-05-06 14:22:53 -07:00
return ServerStatusNoSlotsAvailable , nil
2024-03-14 10:24:13 -07:00
case "loading model" :
return ServerStatusLoadingModel , nil
default :
return ServerStatusError , fmt . Errorf ( "server error: %+v" , status )
}
}
2024-05-06 14:22:53 -07:00
// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received.
//
// It re-checks every 5ms and gives up with an error once 10 retries have all
// reported no free slot. Any other status, or any error from the health check,
// is returned immediately.
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
	for attempt := 0; ; attempt++ {
		status, err := s.getServerStatus(ctx)
		if err != nil || status != ServerStatusNoSlotsAvailable {
			return status, err
		}
		if attempt >= 10 {
			return status, fmt.Errorf("no slots available after %d retries", attempt)
		}
		time.Sleep(5 * time.Millisecond)
	}
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) Ping ( ctx context . Context ) error {
2024-03-14 10:24:13 -07:00
_ , err := s . getServerStatus ( ctx )
if err != nil {
slog . Debug ( "server unhealthy" , "error" , err )
return err
}
return nil
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) WaitUntilRunning ( ctx context . Context ) error {
2024-03-14 10:24:13 -07:00
start := time . Now ( )
2024-04-09 16:35:10 -07:00
expiresAt := time . Now ( ) . Add ( 10 * time . Minute ) // be generous with timeout, large models can take a while to load
2024-03-14 10:24:13 -07:00
slog . Info ( "waiting for llama runner to start responding" )
var lastStatus ServerStatus = - 1
2024-04-17 17:39:52 +02:00
2024-03-14 10:24:13 -07:00
for {
select {
2024-03-30 09:50:05 -07:00
case <- ctx . Done ( ) :
slog . Info ( "context expired before server started" )
2024-04-26 17:38:29 -04:00
return fmt . Errorf ( "timed out waiting for llama runner to start: %w" , ctx . Err ( ) )
2024-03-14 10:24:13 -07:00
case err := <- s . done :
msg := ""
if s . status != nil && s . status . LastErrMsg != "" {
msg = s . status . LastErrMsg
}
return fmt . Errorf ( "llama runner process has terminated: %v %s" , err , msg )
2024-05-08 16:44:35 -07:00
default :
}
2024-04-17 17:39:52 +02:00
if time . Now ( ) . After ( expiresAt ) {
// timeout
2024-03-14 10:24:13 -07:00
msg := ""
if s . status != nil && s . status . LastErrMsg != "" {
msg = s . status . LastErrMsg
}
2024-04-17 17:39:52 +02:00
return fmt . Errorf ( "timed out waiting for llama runner to start: %s" , msg )
}
if s . cmd . ProcessState != nil {
msg := ""
if s . status != nil && s . status . LastErrMsg != "" {
msg = s . status . LastErrMsg
2024-03-14 10:24:13 -07:00
}
2024-04-17 17:39:52 +02:00
return fmt . Errorf ( "llama runner process no longer running: %d %s" , s . cmd . ProcessState . ExitCode ( ) , msg )
}
2024-05-09 11:10:28 -07:00
ctx , cancel := context . WithTimeout ( ctx , 200 * time . Millisecond )
defer cancel ( )
status , _ := s . getServerStatus ( ctx )
if lastStatus != status && status != ServerStatusReady {
// Only log on status changes
slog . Info ( "waiting for server to become available" , "status" , status . ToString ( ) )
}
2024-04-17 17:39:52 +02:00
switch status {
case ServerStatusReady :
2024-05-09 11:10:28 -07:00
s . loadDuration = time . Since ( start )
slog . Info ( fmt . Sprintf ( "llama runner started in %0.2f seconds" , s . loadDuration . Seconds ( ) ) )
2024-04-17 17:39:52 +02:00
return nil
default :
2024-05-09 11:10:28 -07:00
lastStatus = status
2024-04-17 17:39:52 +02:00
time . Sleep ( time . Millisecond * 250 )
continue
2024-03-14 10:24:13 -07:00
}
}
}
// jsonGrammar is the GBNF grammar sent as the "grammar" field of a completion
// request when the caller asks for JSON output. Its root rule is a JSON
// object.
const jsonGrammar = `
root   ::= object
value  ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
// maxBufferSize is both the initial capacity and the maximum size of the
// buffer given to the scanner that reads the completion response stream.
const maxBufferSize = 512 * format . KiloByte
// ImageData is one image attached to a completion request; a slice of these is
// sent as the request's "image_data" field.
type ImageData struct {
// Data holds the image content as bytes.
Data [ ] byte ` json:"data" `
// ID is the identifier sent alongside the image.
ID int ` json:"id" `
}
// completion is one event decoded from the server's streaming completion
// response.
type completion struct {
	Content      string `json:"content"`       // generated text carried by this event
	Model        string `json:"model"`
	Prompt       string `json:"prompt"`
	Stop         bool   `json:"stop"`          // true on the final event of a stream
	StoppedLimit bool   `json:"stopped_limit"` // when set on the final event, the done reason is "length"

	// Timings carries the counts and millisecond durations reported with the
	// final event.
	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}
// CompletionRequest is the input to LlamaServer.Completion.
type CompletionRequest struct {
// Prompt is the text sent as the request's "prompt" field.
Prompt string
// Format, when set to "json", makes Completion attach the JSON grammar to the request.
Format string
// Images are sent as the request's "image_data" field.
Images [ ] ImageData
// Options supplies the sampling and prediction settings copied into the request.
Options api . Options
}
// CompletionResponse is the value passed to the Completion callback. Streamed
// pieces carry only Content; the final response has Done set together with
// DoneReason and the evaluation statistics.
type CompletionResponse struct {
	Content            string        // generated text for this piece
	DoneReason         string        // "stop" or "length"; set on the final response
	Done               bool          // true only on the final response
	PromptEvalCount    int           // prompt token count reported by the server
	PromptEvalDuration time.Duration // prompt evaluation time reported by the server
	EvalCount          int           // predicted token count reported by the server
	EvalDuration       time.Duration // prediction time reported by the server
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) Completion ( ctx context . Context , req CompletionRequest , fn func ( CompletionResponse ) ) error {
if err := s . sem . Acquire ( ctx , 1 ) ; err != nil {
slog . Error ( "Failed to acquire semaphore" , "error" , err )
return err
}
defer s . sem . Release ( 1 )
2024-04-25 19:02:30 -04:00
// only allow maximum 10 "context shifts" to avoid infinite generation
if req . Options . NumPredict < 0 || req . Options . NumPredict > 10 * s . options . NumCtx {
req . Options . NumPredict = 10 * s . options . NumCtx
slog . Debug ( "setting token limit to 10x num_ctx" , "num_ctx" , s . options . NumCtx , "num_predict" , req . Options . NumPredict )
}
2024-03-14 10:24:13 -07:00
request := map [ string ] any {
"prompt" : req . Prompt ,
"stream" : true ,
"n_predict" : req . Options . NumPredict ,
"n_keep" : req . Options . NumKeep ,
"main_gpu" : req . Options . MainGPU ,
"temperature" : req . Options . Temperature ,
"top_k" : req . Options . TopK ,
"top_p" : req . Options . TopP ,
"tfs_z" : req . Options . TFSZ ,
"typical_p" : req . Options . TypicalP ,
"repeat_last_n" : req . Options . RepeatLastN ,
"repeat_penalty" : req . Options . RepeatPenalty ,
"presence_penalty" : req . Options . PresencePenalty ,
"frequency_penalty" : req . Options . FrequencyPenalty ,
"mirostat" : req . Options . Mirostat ,
"mirostat_tau" : req . Options . MirostatTau ,
"mirostat_eta" : req . Options . MirostatEta ,
"penalize_nl" : req . Options . PenalizeNewline ,
"seed" : req . Options . Seed ,
"stop" : req . Options . Stop ,
"image_data" : req . Images ,
"cache_prompt" : true ,
}
// Make sure the server is ready
2024-05-06 14:22:53 -07:00
status , err := s . getServerStatusRetry ( ctx )
2024-03-14 10:24:13 -07:00
if err != nil {
return err
} else if status != ServerStatusReady {
2024-03-30 09:50:05 -07:00
return fmt . Errorf ( "unexpected server status: %s" , status . ToString ( ) )
2024-03-14 10:24:13 -07:00
}
if req . Format == "json" {
request [ "grammar" ] = jsonGrammar
if ! strings . Contains ( strings . ToLower ( req . Prompt ) , "json" ) {
slog . Warn ( "Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt." )
}
}
2024-05-06 14:22:53 -07:00
// Handling JSON marshaling with special characters unescaped.
buffer := & bytes . Buffer { }
enc := json . NewEncoder ( buffer )
enc . SetEscapeHTML ( false )
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
if err := enc . Encode ( request ) ; err != nil {
return fmt . Errorf ( "failed to marshal data: %v" , err )
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
endpoint := fmt . Sprintf ( "http://127.0.0.1:%d/completion" , s . port )
serverReq , err := http . NewRequestWithContext ( ctx , http . MethodPost , endpoint , buffer )
if err != nil {
return fmt . Errorf ( "error creating POST request: %v" , err )
}
serverReq . Header . Set ( "Content-Type" , "application/json" )
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
res , err := http . DefaultClient . Do ( serverReq )
if err != nil {
return fmt . Errorf ( "POST predict: %v" , err )
}
defer res . Body . Close ( )
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
if res . StatusCode >= 400 {
bodyBytes , err := io . ReadAll ( res . Body )
2024-03-14 10:24:13 -07:00
if err != nil {
2024-05-06 14:22:53 -07:00
return fmt . Errorf ( "failed reading llm error response: %w" , err )
2024-03-14 10:24:13 -07:00
}
2024-05-06 14:22:53 -07:00
log . Printf ( "llm predict error: %s" , bodyBytes )
return fmt . Errorf ( "%s" , bodyBytes )
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
scanner := bufio . NewScanner ( res . Body )
buf := make ( [ ] byte , 0 , maxBufferSize )
scanner . Buffer ( buf , maxBufferSize )
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
// keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string
var tokenRepeat int
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
for scanner . Scan ( ) {
select {
case <- ctx . Done ( ) :
// This handles the request cancellation
return ctx . Err ( )
default :
line := scanner . Bytes ( )
if len ( line ) == 0 {
continue
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
evt , ok := bytes . CutPrefix ( line , [ ] byte ( "data: " ) )
if ! ok {
return fmt . Errorf ( "error parsing llm response stream: %s" , line )
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
var c completion
if err := json . Unmarshal ( evt , & c ) ; err != nil {
return fmt . Errorf ( "error unmarshaling llm prediction response: %v" , err )
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
switch {
case strings . TrimSpace ( c . Content ) == lastToken :
tokenRepeat ++
default :
lastToken = strings . TrimSpace ( c . Content )
tokenRepeat = 0
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
// 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 {
slog . Debug ( "prediction aborted, token repeat limit reached" )
return ctx . Err ( )
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
if c . Content != "" {
fn ( CompletionResponse {
Content : c . Content ,
} )
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
if c . Stop {
2024-05-09 13:30:14 -07:00
doneReason := "stop"
if c . StoppedLimit {
doneReason = "length"
}
2024-05-06 14:22:53 -07:00
fn ( CompletionResponse {
Done : true ,
2024-05-09 13:30:14 -07:00
DoneReason : doneReason ,
2024-05-06 14:22:53 -07:00
PromptEvalCount : c . Timings . PromptN ,
PromptEvalDuration : parseDurationMs ( c . Timings . PromptMS ) ,
EvalCount : c . Timings . PredictedN ,
EvalDuration : parseDurationMs ( c . Timings . PredictedMS ) ,
} )
return nil
2024-03-14 10:24:13 -07:00
}
}
2024-05-06 14:22:53 -07:00
}
2024-03-14 10:24:13 -07:00
2024-05-06 14:22:53 -07:00
if err := scanner . Err ( ) ; err != nil {
if strings . Contains ( err . Error ( ) , "unexpected EOF" ) {
s . Close ( )
msg := ""
if s . status != nil && s . status . LastErrMsg != "" {
msg = s . status . LastErrMsg
2024-03-14 10:24:13 -07:00
}
2024-05-06 14:22:53 -07:00
return fmt . Errorf ( "an unknown error was encountered while running the model %s" , msg )
2024-03-14 10:24:13 -07:00
}
2024-05-06 14:22:53 -07:00
return fmt . Errorf ( "error reading llm response: %v" , err )
2024-03-14 10:24:13 -07:00
}
2024-05-06 14:22:53 -07:00
return nil
2024-03-14 10:24:13 -07:00
}
// EmbeddingRequest is a JSON request body carrying the text to embed in its
// "content" field.
type EmbeddingRequest struct {
Content string ` json:"content" `
}
// EmbeddingResponse is the JSON body decoded from the server's /embedding
// endpoint.
type EmbeddingResponse struct {
Embedding [ ] float64 ` json:"embedding" `
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) Embedding ( ctx context . Context , prompt string ) ( [ ] float64 , error ) {
if err := s . sem . Acquire ( ctx , 1 ) ; err != nil {
slog . Error ( "Failed to acquire semaphore" , "error" , err )
return nil , err
}
defer s . sem . Release ( 1 )
2024-05-06 14:22:53 -07:00
2024-03-14 10:24:13 -07:00
// Make sure the server is ready
2024-05-06 14:22:53 -07:00
status , err := s . getServerStatusRetry ( ctx )
2024-03-14 10:24:13 -07:00
if err != nil {
return nil , err
} else if status != ServerStatusReady {
2024-03-30 09:50:05 -07:00
return nil , fmt . Errorf ( "unexpected server status: %s" , status . ToString ( ) )
2024-03-14 10:24:13 -07:00
}
data , err := json . Marshal ( TokenizeRequest { Content : prompt } )
if err != nil {
return nil , fmt . Errorf ( "error marshaling embed data: %w" , err )
}
req , err := http . NewRequestWithContext ( ctx , http . MethodPost , fmt . Sprintf ( "http://127.0.0.1:%d/embedding" , s . port ) , bytes . NewBuffer ( data ) )
if err != nil {
return nil , fmt . Errorf ( "error creating embed request: %w" , err )
}
req . Header . Set ( "Content-Type" , "application/json" )
resp , err := http . DefaultClient . Do ( req )
if err != nil {
return nil , fmt . Errorf ( "do embedding request: %w" , err )
}
defer resp . Body . Close ( )
body , err := io . ReadAll ( resp . Body )
if err != nil {
return nil , fmt . Errorf ( "error reading embed response: %w" , err )
}
if resp . StatusCode >= 400 {
log . Printf ( "llm encode error: %s" , body )
return nil , fmt . Errorf ( "%s" , body )
}
var embedding EmbeddingResponse
if err := json . Unmarshal ( body , & embedding ) ; err != nil {
return nil , fmt . Errorf ( "unmarshal tokenize response: %w" , err )
}
return embedding . Embedding , nil
}
// TokenizeRequest is the JSON body posted to the server's /tokenize endpoint.
type TokenizeRequest struct {
Content string ` json:"content" `
}
// TokenizeResponse is the JSON body decoded from the server's /tokenize
// endpoint.
type TokenizeResponse struct {
Tokens [ ] int ` json:"tokens" `
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) Tokenize ( ctx context . Context , content string ) ( [ ] int , error ) {
2024-03-14 10:24:13 -07:00
// Make sure the server is ready
status , err := s . getServerStatus ( ctx )
if err != nil {
return nil , err
2024-05-06 14:22:53 -07:00
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
2024-03-30 09:50:05 -07:00
return nil , fmt . Errorf ( "unexpected server status: %s" , status . ToString ( ) )
2024-03-14 10:24:13 -07:00
}
data , err := json . Marshal ( TokenizeRequest { Content : content } )
if err != nil {
return nil , fmt . Errorf ( "marshaling encode data: %w" , err )
}
req , err := http . NewRequestWithContext ( ctx , http . MethodPost , fmt . Sprintf ( "http://127.0.0.1:%d/tokenize" , s . port ) , bytes . NewBuffer ( data ) )
if err != nil {
return nil , fmt . Errorf ( "encode request: %w" , err )
}
req . Header . Set ( "Content-Type" , "application/json" )
resp , err := http . DefaultClient . Do ( req )
if err != nil {
return nil , fmt . Errorf ( "do encode request: %w" , err )
}
defer resp . Body . Close ( )
body , err := io . ReadAll ( resp . Body )
if err != nil {
return nil , fmt . Errorf ( "read encode request: %w" , err )
}
if resp . StatusCode >= 400 {
log . Printf ( "llm encode error: %s" , body )
return nil , fmt . Errorf ( "%s" , body )
}
var encoded TokenizeResponse
if err := json . Unmarshal ( body , & encoded ) ; err != nil {
return nil , fmt . Errorf ( "unmarshal encode response: %w" , err )
}
return encoded . Tokens , nil
}
// DetokenizeRequest is the JSON body posted to the server's /detokenize
// endpoint.
type DetokenizeRequest struct {
Tokens [ ] int ` json:"tokens" `
}
// DetokenizeResponse is the JSON body decoded from the server's /detokenize
// endpoint.
type DetokenizeResponse struct {
Content string ` json:"content" `
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) Detokenize ( ctx context . Context , tokens [ ] int ) ( string , error ) {
2024-03-14 10:24:13 -07:00
// Make sure the server is ready
status , err := s . getServerStatus ( ctx )
if err != nil {
return "" , err
2024-05-06 14:22:53 -07:00
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
2024-03-30 09:50:05 -07:00
return "" , fmt . Errorf ( "unexpected server status: %s" , status . ToString ( ) )
2024-03-14 10:24:13 -07:00
}
data , err := json . Marshal ( DetokenizeRequest { Tokens : tokens } )
if err != nil {
return "" , fmt . Errorf ( "marshaling decode data: %w" , err )
}
req , err := http . NewRequestWithContext ( ctx , http . MethodPost , fmt . Sprintf ( "http://127.0.0.1:%d/detokenize" , s . port ) , bytes . NewBuffer ( data ) )
if err != nil {
return "" , fmt . Errorf ( "decode request: %w" , err )
}
req . Header . Set ( "Content-Type" , "application/json" )
resp , err := http . DefaultClient . Do ( req )
if err != nil {
return "" , fmt . Errorf ( "do decode request: %w" , err )
}
defer resp . Body . Close ( )
body , err := io . ReadAll ( resp . Body )
if err != nil {
return "" , fmt . Errorf ( "read decode request: %w" , err )
}
if resp . StatusCode >= 400 {
log . Printf ( "llm decode error: %s" , body )
return "" , fmt . Errorf ( "%s" , body )
}
var decoded DetokenizeResponse
if err := json . Unmarshal ( body , & decoded ) ; err != nil {
return "" , fmt . Errorf ( "unmarshal encode response: %w" , err )
}
return decoded . Content , nil
}
2024-03-30 09:50:05 -07:00
func ( s * llmServer ) Close ( ) error {
2024-03-14 10:24:13 -07:00
if s . cmd != nil {
slog . Debug ( "stopping llama server" )
2024-04-28 16:41:38 +00:00
if err := s . cmd . Process . Kill ( ) ; err != nil {
return err
}
2024-05-09 11:10:28 -07:00
// if ProcessState is already populated, Wait already completed, no need to wait again
if s . cmd . ProcessState == nil {
slog . Debug ( "waiting for llama server to exit" )
<- s . done
}
2024-04-29 18:06:56 +00:00
slog . Debug ( "llama server stopped" )
2024-03-14 10:24:13 -07:00
}
return nil
}
2024-03-30 09:50:05 -07:00
// EstimatedVRAM returns the VRAM usage estimate stored on the server when it
// was created.
func (s *llmServer) EstimatedVRAM() uint64 {
	return s.estimatedVRAM
}
2024-03-14 10:24:13 -07:00
// parseDurationMs converts a duration expressed in (possibly fractional)
// milliseconds into a time.Duration.
//
// Values that cannot be represented (NaN, infinities, or magnitudes that
// overflow time.Duration) yield 0 instead of panicking: the input comes from
// the subprocess's timing report and must not be able to crash the caller.
func parseDurationMs(ms float64) time.Duration {
	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
	if err != nil {
		return 0
	}
	return dur
}