2023-08-30 16:35:03 -04:00
package llm
import (
"bufio"
"bytes"
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"math/rand"
"net/http"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
2023-10-19 12:18:31 -04:00
"slices"
2023-08-30 16:35:03 -04:00
"strconv"
"strings"
2023-10-11 12:32:13 -04:00
"sync"
2023-08-30 16:35:03 -04:00
"time"
"github.com/jmorganca/ollama/api"
2023-10-13 14:45:50 -07:00
"github.com/jmorganca/ollama/format"
2023-08-30 16:35:03 -04:00
)
2023-09-07 13:55:37 -04:00
//go:embed llama.cpp/*/build/*/bin/*
2023-08-30 16:35:03 -04:00
var llamaCppEmbed embed . FS
2023-09-18 15:16:32 -04:00
// ModelRunner describes one runner executable that can be tried when starting
// a model.
type ModelRunner struct {
	// Path is the path to the model runner executable.
	Path string
	// Accelerated marks a runner that is skipped by newLlama when no GPU
	// layers are requested.
	Accelerated bool
}
2023-09-14 15:08:13 -04:00
2023-09-21 20:38:49 +01:00
func chooseRunners ( workDir , runnerType string ) [ ] ModelRunner {
2023-09-18 15:16:32 -04:00
buildPath := path . Join ( "llama.cpp" , runnerType , "build" )
2023-10-13 13:00:44 -07:00
var runners [ ] ModelRunner
2023-09-14 15:08:13 -04:00
2023-09-18 15:16:32 -04:00
// set the runners based on the OS
// IMPORTANT: the order of the runners in the array is the priority order
2023-09-07 13:55:37 -04:00
switch runtime . GOOS {
case "darwin" :
2023-10-13 13:00:44 -07:00
runners = [ ] ModelRunner {
{ Path : path . Join ( buildPath , "metal" , "bin" , "ollama-runner" ) } ,
{ Path : path . Join ( buildPath , "cpu" , "bin" , "ollama-runner" ) } ,
2023-09-14 15:08:13 -04:00
}
2023-09-18 15:16:32 -04:00
case "linux" :
2023-10-13 13:00:44 -07:00
runners = [ ] ModelRunner {
{ Path : path . Join ( buildPath , "cuda" , "bin" , "ollama-runner" ) , Accelerated : true } ,
{ Path : path . Join ( buildPath , "cpu" , "bin" , "ollama-runner" ) } ,
2023-09-14 15:08:13 -04:00
}
case "windows" :
// TODO: select windows GPU runner here when available
2023-10-13 13:00:44 -07:00
runners = [ ] ModelRunner {
{ Path : path . Join ( buildPath , "cpu" , "bin" , "Release" , "ollama-runner.exe" ) } ,
2023-09-18 15:16:32 -04:00
}
2023-09-14 15:08:13 -04:00
default :
log . Printf ( "unknown OS, running on CPU: %s" , runtime . GOOS )
2023-10-13 13:00:44 -07:00
runners = [ ] ModelRunner {
{ Path : path . Join ( buildPath , "cpu" , "bin" , "ollama-runner" ) } ,
2023-09-12 11:04:35 -04:00
}
2023-09-07 13:55:37 -04:00
}
2023-08-30 16:35:03 -04:00
2023-09-18 15:16:32 -04:00
runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
for _ , r := range runners {
// find all the files in the runner's bin directory
2023-10-13 13:00:44 -07:00
files , err := fs . Glob ( llamaCppEmbed , path . Join ( path . Dir ( r . Path ) , "*" ) )
2023-09-07 13:55:37 -04:00
if err != nil {
2023-09-18 15:16:32 -04:00
// this is expected, ollama may be compiled without all runners packed in
2023-10-16 16:14:12 -07:00
log . Printf ( "%s runner not found: %v" , r . Path , err )
2023-09-18 15:16:32 -04:00
continue
2023-09-07 13:55:37 -04:00
}
2023-08-30 16:35:03 -04:00
2023-09-18 15:16:32 -04:00
for _ , f := range files {
2023-09-29 11:47:55 -04:00
runnerAvailable = true
2023-09-18 15:16:32 -04:00
srcFile , err := llamaCppEmbed . Open ( f )
if err != nil {
log . Fatalf ( "read llama runner %s: %v" , f , err )
}
defer srcFile . Close ( )
2023-09-29 11:47:55 -04:00
// create the directory in case it does not exist, filepath.Dir() converts the file path to the OS's format
2023-09-21 20:38:49 +01:00
destPath := filepath . Join ( workDir , filepath . Dir ( f ) )
2023-09-18 15:16:32 -04:00
if err := os . MkdirAll ( destPath , 0 o755 ) ; err != nil {
log . Fatalf ( "create runner temp dir %s: %v" , filepath . Dir ( f ) , err )
}
2023-08-30 16:35:03 -04:00
2023-09-29 11:47:55 -04:00
// create the path to the destination file, filepath.Base() converts the file path to the OS's format
2023-09-21 20:38:49 +01:00
destFile := filepath . Join ( destPath , filepath . Base ( f ) )
_ , err = os . Stat ( destFile )
switch {
case errors . Is ( err , os . ErrNotExist ) :
destFile , err := os . OpenFile ( destFile , os . O_WRONLY | os . O_CREATE | os . O_TRUNC , 0 o755 )
if err != nil {
log . Fatalf ( "write llama runner %s: %v" , f , err )
}
defer destFile . Close ( )
if _ , err := io . Copy ( destFile , srcFile ) ; err != nil {
log . Fatalf ( "copy llama runner %s: %v" , f , err )
}
case err != nil :
log . Fatalf ( "stat llama runner %s: %v" , f , err )
2023-09-18 15:16:32 -04:00
}
2023-08-30 16:35:03 -04:00
}
2023-09-07 13:55:37 -04:00
}
2023-09-18 15:16:32 -04:00
if ! runnerAvailable {
log . Fatalf ( "%s runner not found" , runnerType )
}
2023-08-30 16:35:03 -04:00
2023-09-18 15:16:32 -04:00
// return the runners to try in priority order
localRunnersByPriority := [ ] ModelRunner { }
for _ , r := range runners {
2023-09-29 11:47:55 -04:00
// clean the ModelRunner paths so that they match the OS we are running on
2023-10-13 13:00:44 -07:00
localRunnersByPriority = append ( localRunnersByPriority , ModelRunner {
Path : filepath . Clean ( path . Join ( workDir , r . Path ) ) ,
Accelerated : r . Accelerated ,
} )
2023-09-07 13:55:37 -04:00
}
2023-08-30 16:35:03 -04:00
2023-09-18 15:16:32 -04:00
return localRunnersByPriority
2023-08-30 16:35:03 -04:00
}
// llamaModel wraps the hyperparameters of a llama model and exposes them
// through its accessor methods.
type llamaModel struct {
	hyperparameters llamaHyperparameters
}
2023-09-12 10:01:20 -07:00
func ( llm * llamaModel ) ModelFamily ( ) string {
return "llama"
2023-08-30 16:35:03 -04:00
}
2023-09-12 10:01:20 -07:00
// llamaModelType maps a layer count to a model size label such as "7B".
// Layer counts that are not in the table yield "unknown".
func llamaModelType(numLayer uint32) string {
	sizes := [...]struct {
		layers uint32
		label  string
	}{
		{26, "3B"},
		{32, "7B"},
		{40, "13B"},
		{48, "34B"},
		{60, "30B"},
		{80, "65B"},
	}

	for _, s := range sizes {
		if s.layers == numLayer {
			return s.label
		}
	}

	return "unknown"
}
2023-08-30 16:35:03 -04:00
2023-09-12 10:01:20 -07:00
func ( llm * llamaModel ) ModelType ( ) string {
return llamaModelType ( llm . hyperparameters . NumLayer )
2023-08-30 16:35:03 -04:00
}
2023-09-12 10:01:20 -07:00
func ( llm * llamaModel ) FileType ( ) string {
return fileType ( llm . hyperparameters . FileType )
2023-08-30 16:35:03 -04:00
}
2023-09-25 23:36:46 +01:00
// NumLayers returns the model's layer count widened to int64.
func (llm *llamaModel) NumLayers() int64 {
	n := llm.hyperparameters.NumLayer
	return int64(n)
}
2023-08-30 16:35:03 -04:00
// llamaHyperparameters holds the hyperparameters of a llama model.
// NOTE(review): the field order presumably mirrors the on-disk layout, so it
// must not be rearranged; confirm against the code that decodes it.
type llamaHyperparameters struct {
	// NumVocab is the size of the model's vocabulary.
	NumVocab uint32

	// NumEmbd is the size of the model's embedding layer.
	NumEmbd uint32
	NumMult uint32
	NumHead uint32

	// NumLayer is the number of layers in the model.
	NumLayer uint32
	NumRot   uint32

	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
	FileType uint32
}
type Running struct {
2023-10-18 15:36:56 -04:00
Port int
Cmd * exec . Cmd
Cancel context . CancelFunc
exitOnce sync . Once
exitCh chan error // channel to receive the exit status of the subprocess
* StatusWriter // captures error messages from the llama runner process
2023-08-30 16:35:03 -04:00
}
// llama combines the options a model was started with and the state of its
// runner subprocess. Both are embedded, so their fields are promoted.
type llama struct {
	api.Options
	Running
}
2023-09-12 11:04:35 -04:00
// errNoGPU is returned by CheckVRAM when the nvidia-smi command cannot be run
// successfully.
var errNoGPU = errors.New("nvidia-smi command failed")
2023-10-13 14:45:50 -07:00
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
2023-09-28 10:00:34 -07:00
func CheckVRAM ( ) ( int64 , error ) {
2023-10-10 16:16:09 -04:00
cmd := exec . Command ( "nvidia-smi" , "--query-gpu=memory.free" , "--format=csv,noheader,nounits" )
2023-09-12 11:04:35 -04:00
var stdout bytes . Buffer
cmd . Stdout = & stdout
err := cmd . Run ( )
if err != nil {
return 0 , errNoGPU
}
2023-10-13 14:45:50 -07:00
var freeMiB int64
2023-09-12 11:04:35 -04:00
scanner := bufio . NewScanner ( & stdout )
for scanner . Scan ( ) {
line := scanner . Text ( )
2023-09-28 10:00:34 -07:00
vram , err := strconv . ParseInt ( strings . TrimSpace ( line ) , 10 , 64 )
2023-09-12 11:04:35 -04:00
if err != nil {
return 0 , fmt . Errorf ( "failed to parse available VRAM: %v" , err )
}
2023-10-13 14:45:50 -07:00
freeMiB += vram
2023-09-12 11:04:35 -04:00
}
2023-10-13 14:45:50 -07:00
freeBytes := freeMiB * 1024 * 1024
if freeBytes < 2 * format . GigaByte {
2023-10-13 12:58:54 -07:00
log . Printf ( "less than 2 GB VRAM available, falling back to CPU only" )
2023-10-13 14:45:50 -07:00
freeMiB = 0
2023-10-13 12:58:54 -07:00
}
2023-10-13 14:45:50 -07:00
return freeBytes , nil
2023-09-12 11:04:35 -04:00
}
2023-09-25 23:36:46 +01:00
func NumGPU ( numLayer , fileSizeBytes int64 , opts api . Options ) int {
2023-09-12 11:04:35 -04:00
if opts . NumGPU != - 1 {
return opts . NumGPU
}
if runtime . GOOS == "linux" {
2023-10-13 14:45:50 -07:00
freeBytes , err := CheckVRAM ( )
2023-09-12 11:04:35 -04:00
if err != nil {
if err . Error ( ) != "nvidia-smi command failed" {
log . Print ( err . Error ( ) )
}
// nvidia driver not installed or no nvidia GPU found
return 0
}
2023-09-25 23:36:46 +01:00
// Calculate bytes per layer
// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
bytesPerLayer := fileSizeBytes / numLayer
2023-10-13 16:57:10 -04:00
// max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory
2023-10-13 14:45:50 -07:00
layers := int ( freeBytes / bytesPerLayer ) * 92 / 100
2023-10-17 15:35:16 -04:00
log . Printf ( "%d MB VRAM available, loading up to %d GPU layers" , freeBytes / ( 1024 * 1024 ) , layers )
2023-09-25 23:36:46 +01:00
2023-10-02 14:53:42 -04:00
return layers
2023-09-12 11:04:35 -04:00
}
2023-09-25 23:36:46 +01:00
// default to enable metal on macOS
return 1
2023-09-12 11:04:35 -04:00
}
2023-10-12 11:16:37 -04:00
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
	// ErrCh receives the first error message seen until a reader drains it.
	ErrCh chan error
	// LastErrMsg is the most recent error message seen by Write.
	LastErrMsg string
}

// NewStatusWriter returns a StatusWriter whose ErrCh can hold one pending error.
func NewStatusWriter() *StatusWriter {
	return &StatusWriter{
		ErrCh: make(chan error, 1),
	}
}

// Write scans b for an "error:" or "CUDA error" marker and records the text
// that follows it, then forwards b unchanged to os.Stderr and returns that
// write's result.
//
// The send on ErrCh never blocks: when an earlier error is still pending the
// new one is only kept in LastErrMsg. A blocking send here would stall the
// goroutine copying the runner's stderr as soon as a second error line arrived.
func (w *StatusWriter) Write(b []byte) (int, error) {
	var errMsg string
	if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
		errMsg = string(bytes.TrimSpace(after))
	} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
		errMsg = string(bytes.TrimSpace(after))
	}

	if errMsg != "" {
		w.LastErrMsg = errMsg
		select {
		case w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg):
		default:
			// an earlier error is still waiting to be read, keep that one
		}
	}

	return os.Stderr.Write(b)
}
2023-10-19 12:18:31 -04:00
func newLlama ( model string , adapters [ ] string , runners [ ] ModelRunner , ggml * GGML , opts api . Options ) ( * llama , error ) {
2023-09-25 23:36:46 +01:00
fileInfo , err := os . Stat ( model )
if err != nil {
2023-08-30 16:35:03 -04:00
return nil , err
}
if len ( adapters ) > 1 {
return nil , errors . New ( "ollama supports only one lora adapter, but multiple were provided" )
}
2023-10-19 12:18:31 -04:00
numGPU := NumGPU ( ggml . NumLayers ( ) , fileInfo . Size ( ) , opts )
2023-08-30 16:35:03 -04:00
params := [ ] string {
"--model" , model ,
"--ctx-size" , fmt . Sprintf ( "%d" , opts . NumCtx ) ,
"--rope-freq-base" , fmt . Sprintf ( "%f" , opts . RopeFrequencyBase ) ,
"--rope-freq-scale" , fmt . Sprintf ( "%f" , opts . RopeFrequencyScale ) ,
"--batch-size" , fmt . Sprintf ( "%d" , opts . NumBatch ) ,
2023-10-16 14:37:17 -07:00
"--n-gpu-layers" , fmt . Sprintf ( "%d" , numGPU ) ,
2023-08-30 16:35:03 -04:00
"--embedding" ,
}
2023-09-07 13:55:37 -04:00
if opts . NumGQA > 0 {
params = append ( params , "--gqa" , fmt . Sprintf ( "%d" , opts . NumGQA ) )
}
2023-08-30 16:35:03 -04:00
if len ( adapters ) > 0 {
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
params = append ( params , "--lora" , adapters [ 0 ] )
}
if opts . NumThread > 0 {
params = append ( params , "--threads" , fmt . Sprintf ( "%d" , opts . NumThread ) )
}
if ! opts . F16KV {
params = append ( params , "--memory-f32" )
}
if opts . UseMLock {
params = append ( params , "--mlock" )
}
if ! opts . UseMMap {
params = append ( params , "--no-mmap" )
}
if opts . UseNUMA {
params = append ( params , "--numa" )
}
2023-10-12 11:16:37 -04:00
var runnerErr error
2023-08-30 16:35:03 -04:00
// start the llama.cpp server with a retry in case the port is already in use
2023-09-18 15:16:32 -04:00
for _ , runner := range runners {
2023-10-13 13:00:44 -07:00
if runner . Accelerated && numGPU == 0 {
log . Printf ( "skipping accelerated runner because num_gpu=0" )
continue
}
2023-09-18 15:16:32 -04:00
if _ , err := os . Stat ( runner . Path ) ; err != nil {
log . Printf ( "llama runner not found: %v" , err )
continue
}
2023-08-30 16:35:03 -04:00
port := rand . Intn ( 65535 - 49152 ) + 49152 // get a random port in the ephemeral range
ctx , cancel := context . WithCancel ( context . Background ( ) )
cmd := exec . CommandContext (
ctx ,
runner . Path ,
append ( params , "--port" , strconv . Itoa ( port ) ) ... ,
)
2023-09-20 17:40:42 +01:00
cmd . Env = append ( os . Environ ( ) , fmt . Sprintf ( "LD_LIBRARY_PATH=%s" , filepath . Dir ( runner . Path ) ) )
2023-09-03 14:10:03 -04:00
cmd . Stdout = os . Stderr
2023-10-12 11:16:37 -04:00
statusWriter := NewStatusWriter ( )
cmd . Stderr = statusWriter
2023-08-30 16:35:03 -04:00
2023-10-11 12:32:13 -04:00
llm := & llama { Options : opts , Running : Running { Port : port , Cmd : cmd , Cancel : cancel , exitCh : make ( chan error ) } }
2023-08-30 16:35:03 -04:00
2023-09-18 15:16:32 -04:00
log . Print ( "starting llama runner" )
2023-09-07 13:55:37 -04:00
if err := llm . Cmd . Start ( ) ; err != nil {
2023-09-18 15:16:32 -04:00
log . Printf ( "error starting the external llama runner: %v" , err )
2023-09-07 13:55:37 -04:00
continue
}
2023-10-11 12:32:13 -04:00
// monitor the llama runner process and signal when it exits
2023-09-18 15:16:32 -04:00
go func ( ) {
2023-10-11 12:32:13 -04:00
err := llm . Cmd . Wait ( )
2023-10-18 15:36:56 -04:00
// default to printing the exit message of the command process, it will probably just say 'exit staus 1'
errMsg := err . Error ( )
// try to set a better error message if llama runner logs captured an error
if statusWriter . LastErrMsg != "" {
errMsg = statusWriter . LastErrMsg
}
log . Println ( errMsg )
2023-10-11 12:32:13 -04:00
// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
llm . exitOnce . Do ( func ( ) {
close ( llm . exitCh )
} )
2023-09-18 15:16:32 -04:00
} ( )
2023-08-30 16:35:03 -04:00
if err := waitForServer ( llm ) ; err != nil {
2023-09-18 15:16:32 -04:00
log . Printf ( "error starting llama runner: %v" , err )
2023-08-30 16:35:03 -04:00
llm . Close ( )
2023-10-12 11:16:37 -04:00
// default the runnerErr to the error returned by the most recent llama runner process
runnerErr = err
// capture the error directly from the runner process, if any
select {
case runnerErr = <- statusWriter . ErrCh :
default :
// the runner process probably timed out
}
2023-08-30 16:35:03 -04:00
// try again
continue
}
2023-09-07 13:55:37 -04:00
2023-08-30 16:35:03 -04:00
// server started successfully
return llm , nil
}
2023-10-12 11:16:37 -04:00
if runnerErr != nil {
// this is the error returned from the llama runner process that failed most recently
2023-10-19 12:18:31 -04:00
// falcon and starcoder model families are not compatible with older versions of llama.cpp
families := [ ] string { "falcon" , "starcoder" }
if strings . Contains ( runnerErr . Error ( ) , "failed to load model" ) && slices . Contains ( families , ggml . ModelFamily ( ) ) {
return nil , fmt . Errorf ( "%v: %s" , runnerErr , "this model may be incompatible with your version of Ollama. Please run `ollama pull` to get the latest version of this model." )
}
2023-10-12 11:16:37 -04:00
return nil , runnerErr
}
2023-09-18 15:16:32 -04:00
return nil , fmt . Errorf ( "failed to start a llama runner" )
2023-08-30 16:35:03 -04:00
}
func waitForServer ( llm * llama ) error {
start := time . Now ( )
2023-10-12 11:16:37 -04:00
expiresAt := time . Now ( ) . Add ( 3 * time . Minute ) // be generous with timeout, large models can take a while to load
2023-09-07 13:55:37 -04:00
ticker := time . NewTicker ( 200 * time . Millisecond )
2023-10-11 12:32:13 -04:00
defer ticker . Stop ( )
2023-08-30 16:35:03 -04:00
2023-09-18 15:16:32 -04:00
log . Print ( "waiting for llama runner to start responding" )
2023-10-11 12:32:13 -04:00
for {
select {
case <- llm . exitCh :
// failed to start subprocess
2023-09-18 15:16:32 -04:00
return fmt . Errorf ( "llama runner process has terminated" )
2023-10-11 12:32:13 -04:00
case <- ticker . C :
if time . Now ( ) . After ( expiresAt ) {
// timeout
2023-10-12 11:16:37 -04:00
return fmt . Errorf ( "timed out waiting for llama runner to start" )
2023-10-11 12:32:13 -04:00
}
2023-08-30 16:35:03 -04:00
2023-10-11 12:32:13 -04:00
if err := llm . Ping ( context . Background ( ) ) ; err == nil {
// success
log . Printf ( "llama runner started in %f seconds" , time . Since ( start ) . Seconds ( ) )
return nil
}
2023-08-30 16:35:03 -04:00
}
}
}
func ( llm * llama ) Close ( ) {
2023-10-10 16:16:09 -04:00
// signal the sub-process to terminate
2023-09-07 13:55:37 -04:00
llm . Cancel ( )
2023-10-10 16:16:09 -04:00
// wait for the command to exit to prevent race conditions with the next run
2023-10-11 12:32:13 -04:00
<- llm . exitCh
2023-10-18 15:36:56 -04:00
if llm . StatusWriter != nil && llm . StatusWriter . LastErrMsg != "" {
log . Printf ( "llama runner stopped with error: %v" , llm . StatusWriter . LastErrMsg )
2023-10-11 12:32:13 -04:00
} else {
log . Print ( "llama runner stopped successfully" )
2023-10-10 16:16:09 -04:00
}
2023-08-30 16:35:03 -04:00
}
// SetOptions replaces the options stored on the handle with opts. It only
// assigns the field; the running subprocess is not touched here.
func (llm *llama) SetOptions(opts api.Options) {
	llm.Options = opts
}
2023-10-16 16:31:29 -07:00
// prediction is one event decoded from the runner's streamed completion
// response.
type prediction struct {
	// Content is the text carried by this event.
	Content string `json:"content"`
	Model   string `json:"model"`
	Prompt  string `json:"prompt"`
	// Stop marks the final event of a completion.
	Stop bool `json:"stop"`

	// Timings carries the counts and millisecond durations that Predict
	// reports on the final response.
	Timings struct {
		PredictedN  int     `json:"predicted_n"`
		PredictedMS float64 `json:"predicted_ms"`
		PromptN     int     `json:"prompt_n"`
		PromptMS    float64 `json:"prompt_ms"`
	}
}
2023-10-12 09:34:16 -07:00
// maxBufferSize is the largest line the scanner in Predict accepts from the
// runner's streamed response.
const maxBufferSize = 512 * format.KiloByte
2023-10-04 14:09:00 -04:00
2023-09-03 14:10:03 -04:00
func ( llm * llama ) Predict ( ctx context . Context , prevContext [ ] int , prompt string , fn func ( api . GenerateResponse ) ) error {
prevConvo , err := llm . Decode ( ctx , prevContext )
2023-08-30 16:35:03 -04:00
if err != nil {
2023-09-03 14:10:03 -04:00
return err
2023-08-30 16:35:03 -04:00
}
2023-09-03 14:10:03 -04:00
2023-10-18 22:41:19 +02:00
// Remove leading spaces from prevConvo if present
2023-10-18 22:51:30 +02:00
prevConvo = strings . TrimPrefix ( prevConvo , " " )
2023-10-18 20:08:26 +02:00
2023-09-03 14:10:03 -04:00
var nextContext strings . Builder
nextContext . WriteString ( prevConvo )
nextContext . WriteString ( prompt )
2023-10-16 16:31:29 -07:00
request := map [ string ] any {
"prompt" : nextContext . String ( ) ,
"stream" : true ,
"n_predict" : llm . NumPredict ,
"n_keep" : llm . NumKeep ,
"temperature" : llm . Temperature ,
"top_k" : llm . TopK ,
"top_p" : llm . TopP ,
"tfs_z" : llm . TFSZ ,
"typical_p" : llm . TypicalP ,
"repeat_last_n" : llm . RepeatLastN ,
"repeat_penalty" : llm . RepeatPenalty ,
"presence_penalty" : llm . PresencePenalty ,
"frequency_penalty" : llm . FrequencyPenalty ,
"mirostat" : llm . Mirostat ,
"mirostat_tau" : llm . MirostatTau ,
"mirostat_eta" : llm . MirostatEta ,
"penalize_nl" : llm . PenalizeNewline ,
"seed" : llm . Seed ,
"stop" : llm . Stop ,
2023-08-30 16:35:03 -04:00
}
2023-10-02 14:53:16 -04:00
2023-10-16 11:15:55 +02:00
// Handling JSON marshaling with special characters unescaped.
2023-10-17 08:17:35 +02:00
buffer := & bytes . Buffer { }
enc := json . NewEncoder ( buffer )
2023-10-16 11:15:55 +02:00
enc . SetEscapeHTML ( false )
2023-10-16 16:31:29 -07:00
if err := enc . Encode ( request ) ; err != nil {
2023-10-16 11:15:55 +02:00
return fmt . Errorf ( "failed to marshal data: %v" , err )
2023-08-30 16:35:03 -04:00
}
2023-10-16 16:31:29 -07:00
endpoint := fmt . Sprintf ( "http://127.0.0.1:%d/completion" , llm . Port )
2023-10-17 08:17:35 +02:00
req , err := http . NewRequestWithContext ( ctx , http . MethodPost , endpoint , buffer )
2023-08-30 16:35:03 -04:00
if err != nil {
return fmt . Errorf ( "error creating POST request: %v" , err )
}
req . Header . Set ( "Content-Type" , "application/json" )
resp , err := http . DefaultClient . Do ( req )
if err != nil {
return fmt . Errorf ( "POST predict: %v" , err )
}
defer resp . Body . Close ( )
if resp . StatusCode >= 400 {
bodyBytes , err := io . ReadAll ( resp . Body )
if err != nil {
return fmt . Errorf ( "failed reading llm error response: %w" , err )
}
log . Printf ( "llm predict error: %s" , bodyBytes )
return fmt . Errorf ( "%s" , bodyBytes )
}
scanner := bufio . NewScanner ( resp . Body )
2023-10-04 14:09:00 -04:00
// increase the buffer size to avoid running out of space
buf := make ( [ ] byte , 0 , maxBufferSize )
scanner . Buffer ( buf , maxBufferSize )
2023-08-30 16:35:03 -04:00
for scanner . Scan ( ) {
select {
case <- ctx . Done ( ) :
// This handles the request cancellation
return ctx . Err ( )
default :
2023-10-16 16:31:56 -07:00
line := scanner . Bytes ( )
if len ( line ) == 0 {
2023-08-30 16:35:03 -04:00
continue
}
2023-10-16 16:31:56 -07:00
if evt , ok := bytes . CutPrefix ( line , [ ] byte ( "data: " ) ) ; ok {
2023-10-16 16:31:29 -07:00
var p prediction
2023-10-16 16:31:56 -07:00
if err := json . Unmarshal ( evt , & p ) ; err != nil {
2023-09-03 17:46:35 -04:00
return fmt . Errorf ( "error unmarshaling llm prediction response: %v" , err )
2023-08-30 16:35:03 -04:00
}
2023-09-05 15:03:24 -07:00
if p . Content != "" {
fn ( api . GenerateResponse { Response : p . Content } )
nextContext . WriteString ( p . Content )
}
2023-09-03 17:46:35 -04:00
if p . Stop {
2023-09-03 14:10:03 -04:00
embd , err := llm . Encode ( ctx , nextContext . String ( ) )
2023-08-30 16:35:03 -04:00
if err != nil {
return fmt . Errorf ( "encoding context: %v" , err )
}
2023-09-03 14:10:03 -04:00
2023-08-30 16:35:03 -04:00
fn ( api . GenerateResponse {
Done : true ,
Context : embd ,
2023-10-16 16:31:29 -07:00
PromptEvalCount : p . Timings . PromptN ,
PromptEvalDuration : parseDurationMs ( p . Timings . PromptMS ) ,
EvalCount : p . Timings . PredictedN ,
EvalDuration : parseDurationMs ( p . Timings . PredictedMS ) ,
2023-08-30 16:35:03 -04:00
} )
2023-09-03 17:46:35 -04:00
return nil
2023-08-30 16:35:03 -04:00
}
}
}
}
if err := scanner . Err ( ) ; err != nil {
2023-10-18 15:36:56 -04:00
if strings . Contains ( err . Error ( ) , "unexpected EOF" ) {
// this means the llama runner subprocess crashed
llm . Close ( )
if llm . StatusWriter != nil && llm . StatusWriter . LastErrMsg != "" {
return fmt . Errorf ( "llama runner exited: %v" , llm . StatusWriter . LastErrMsg )
}
return fmt . Errorf ( "llama runner exited, you may not have enough available memory to run this model" )
}
2023-08-30 16:35:03 -04:00
return fmt . Errorf ( "error reading llm response: %v" , err )
}
return nil
}
// TokenizeRequest is the JSON body that Encode posts to the runner's
// /tokenize endpoint.
type TokenizeRequest struct {
	// Content is the text to tokenize.
	Content string `json:"content"`
}
// TokenizeResponse is the JSON body that Encode decodes from the runner's
// /tokenize endpoint.
type TokenizeResponse struct {
	// Tokens are the token ids of the submitted text.
	Tokens []int `json:"tokens"`
}
// Encode asks the runner's /tokenize endpoint for the tokens of prompt.
// A response status of 400 or above is logged and returned as an error that
// carries the response body.
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	payload, err := json.Marshal(TokenizeRequest{Content: prompt})
	if err != nil {
		return nil, fmt.Errorf("marshaling encode data: %w", err)
	}

	url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("encode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("do encode request: %w", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read encode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm encode error: %s", raw)
		return nil, fmt.Errorf("%s", raw)
	}

	var result TokenizeResponse
	if err := json.Unmarshal(raw, &result); err != nil {
		return nil, fmt.Errorf("unmarshal encode response: %w", err)
	}

	return result.Tokens, nil
}
// DetokenizeRequest is the JSON body that Decode posts to the runner's
// /detokenize endpoint.
type DetokenizeRequest struct {
	// Tokens are the token ids to turn back into text.
	Tokens []int `json:"tokens"`
}
// DetokenizeResponse is the JSON body that Decode decodes from the runner's
// /detokenize endpoint.
type DetokenizeResponse struct {
	// Content is the text of the submitted tokens.
	Content string `json:"content"`
}
// Decode asks the runner's /detokenize endpoint for the text of tokens.
// An empty token list yields an empty string without contacting the runner.
// A response status of 400 or above is logged and returned as an error that
// carries the response body.
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
	if len(tokens) == 0 {
		return "", nil
	}

	endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
	if err != nil {
		return "", fmt.Errorf("marshaling decode data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return "", fmt.Errorf("decode request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("do decode request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read decode request: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm decode error: %s", body)
		return "", fmt.Errorf("%s", body)
	}

	var decoded DetokenizeResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		// this message used to say "encode", which pointed at the wrong call
		return "", fmt.Errorf("unmarshal decode response: %w", err)
	}

	return decoded.Content, nil
}
// EmbeddingRequest is the JSON request body for the runner's /embedding
// endpoint.
type EmbeddingRequest struct {
	// Content is the text to embed.
	Content string `json:"content"`
}
// EmbeddingResponse is the JSON body that Embedding decodes from the runner's
// /embedding endpoint.
type EmbeddingResponse struct {
	// Embedding is the embedding vector of the submitted text.
	Embedding []float64 `json:"embedding"`
}
// Embedding asks the runner's /embedding endpoint for the embedding vector of
// input. A response status of 400 or above is logged and returned as an error
// that carries the response body.
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
	endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
	// use the request type declared for this endpoint, it encodes to the same JSON
	data, err := json.Marshal(EmbeddingRequest{Content: input})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("POST embedding: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
	}

	var embedding EmbeddingResponse
	if err := json.Unmarshal(body, &embedding); err != nil {
		return nil, fmt.Errorf("unmarshal embedding response: %w", err)
	}

	return embedding.Embedding, nil
}
// Ping checks that the server subprocess is still running and responding to requests
func ( llm * llama ) Ping ( ctx context . Context ) error {
2023-09-07 13:55:37 -04:00
resp , err := http . Head ( fmt . Sprintf ( "http://127.0.0.1:%d" , llm . Port ) )
2023-08-30 16:35:03 -04:00
if err != nil {
return fmt . Errorf ( "ping resp: %w" , err )
}
if resp . StatusCode != http . StatusOK {
return fmt . Errorf ( "unexpected ping status: %s" , resp . Status )
}
return nil
}