Add cgo implementation for llama.cpp

Run the server.cpp directly inside the Go runtime via cgo
while retaining the LLM Go abstractions.
This commit is contained in:
Daniel Hiltgen 2023-11-13 17:20:34 -08:00
parent 5e7fd6906f
commit d4cd695759
27 changed files with 1189 additions and 765 deletions

View file

@ -2,7 +2,7 @@
ollama
app
dist
scripts
llm/llama.cpp/gguf
.env
.cache
test_data

3
.gitignore vendored
View file

@ -8,4 +8,5 @@ ollama
ggml-metal.metal
.cache
*.exe
.idea
.idea
test_data

2
go.mod
View file

@ -7,7 +7,7 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.3
github.com/stretchr/testify v1.8.4
golang.org/x/sync v0.3.0
)

3
go.sum
View file

@ -98,8 +98,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=

325
llm/ext_server.go Normal file
View file

@ -0,0 +1,325 @@
package llm
/*
#cgo CFLAGS: -I${SRCDIR}/llama.cpp/gguf -I${SRCDIR}/llama.cpp/gguf/common
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo CFLAGS: -Wmissing-noreturn -Wall -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
#cgo CPPFLAGS: -Ofast -Wall -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations -Wno-unused-but-set-variable
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
#cgo darwin,arm64 CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
#cgo darwin LDFLAGS: -lc++ -framework Accelerate
#cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
#cgo darwin,arm64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/common/libcommon.a
#cgo darwin,arm64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/examples/server/libext_server.a
#cgo darwin,arm64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/libllama.a
#cgo darwin,arm64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/metal/libggml_static.a
#cgo darwin,amd64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/common/libcommon.a
#cgo darwin,amd64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/examples/server/libext_server.a
#cgo darwin,amd64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/libllama.a
#cgo darwin,amd64 LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cpu/libggml_static.a
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux windows CFLAGS: -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS
#cgo linux LDFLAGS: -L/usr/local/cuda/targets/x86_64-linux/lib -L/usr/local/cuda/lib64 -L/usr/local/cuda/targets/x86_64-linux/lib/stubs
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cuda/examples/server/libext_server.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cuda/common/libcommon.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cuda/libllama.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/gguf/build/cuda/libggml_static.a
#cgo linux LDFLAGS: /usr/local/cuda/lib64/libcudart_static.a
#cgo linux LDFLAGS: /usr/local/cuda/lib64/libcublas_static.a
#cgo linux LDFLAGS: /usr/local/cuda/lib64/libcublasLt_static.a
#cgo linux LDFLAGS: /usr/local/cuda/lib64/libcudadevrt.a
#cgo linux LDFLAGS: /usr/local/cuda/lib64/libculibos.a
#cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm
#cgo windows LDFLAGS: -L${SRCDIR}/llama.cpp/gguf/build/wincuda/dist/bin
#cgo windows LDFLAGS: -lext_server_shared -lpthread
#include <stdlib.h>
#include "examples/server/server.h"
*/
import "C"
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"runtime"
"sync"
"time"
"unsafe"
"github.com/jmorganca/ollama/api"
)
func errWrap(resp C.ext_server_err) error {
if resp.code == 0 {
return nil
}
err := fmt.Errorf(C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
return err
}
type llamaExtServer struct {
api.Options
}
// Note: current implementation does not support concurrent instantiations
var mutex sync.Mutex
func newLlamaExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (*llamaExtServer, error) {
if !mutex.TryLock() {
log.Printf("concurrent llm servers not yet supported, waiting for prior server to complete")
mutex.Lock()
}
server := &llamaExtServer{opts}
fileInfo, err := os.Stat(model)
if err != nil {
return nil, err
}
var sparams C.ext_server_params
sparams.model = C.CString(model)
defer C.free(unsafe.Pointer(sparams.model))
numGPU := NumGPU(numLayers, fileInfo.Size(), opts)
sparams.embedding = true
sparams.n_ctx = C.uint(opts.NumCtx)
sparams.n_batch = C.uint(opts.NumBatch)
sparams.n_gpu_layers = C.int(numGPU)
sparams.main_gpu = C.int(opts.MainGPU)
sparams.n_parallel = 2 // TODO - wire up concurrency
// Always use the value encoded in the model
sparams.rope_freq_base = 0.0
sparams.rope_freq_scale = 0.0
sparams.lora_adapters = nil
for i := 0; i < len(adapters); i++ {
la := (*C.ext_server_lora_adapter)(C.malloc(C.sizeof_struct_ext_server_lora_adapter))
defer C.free(unsafe.Pointer(la))
la.adapter = C.CString(adapters[i])
defer C.free(unsafe.Pointer(la.adapter))
la.scale = C.float(1.0) // TODO expose scale/weights up through ollama UX
la.next = nil
if i == 0 {
sparams.lora_adapters = la
} else {
tmp := sparams.lora_adapters
for ; tmp.next != nil; tmp = tmp.next {
}
tmp.next = la
}
}
// TODO - implement ME
// if len(projectors) > 0 {
// // TODO: applying multiple projectors is not supported by the llama.cpp server yet
// params = append(params, "--mmproj", projectors[0])
// }
if opts.NumThread > 0 {
sparams.n_threads = C.uint(opts.NumThread)
} else {
sparams.n_threads = C.uint(runtime.NumCPU())
}
sparams.memory_f16 = false
if opts.F16KV {
sparams.memory_f16 = true
}
sparams.use_mlock = false
if opts.UseMLock {
sparams.use_mlock = true
}
sparams.use_mmap = true
if !opts.UseMMap {
sparams.use_mmap = false
}
sparams.numa = false
if opts.UseNUMA {
sparams.numa = true
}
log.Printf("Initializing internal llama server")
err = errWrap(C.llama_server_init(&sparams))
if err != nil {
return nil, err
}
log.Printf("Starting internal llama main loop")
C.llama_server_start()
return server, nil
}
func (llm *llamaExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
request := map[string]any{
"prompt": predict.Prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
"temperature": llm.Temperature,
"top_k": llm.TopK,
"top_p": llm.TopP,
"tfs_z": llm.TFSZ,
"typical_p": llm.TypicalP,
"repeat_last_n": llm.RepeatLastN,
"repeat_penalty": llm.RepeatPenalty,
"presence_penalty": llm.PresencePenalty,
"frequency_penalty": llm.FrequencyPenalty,
"mirostat": llm.Mirostat,
"mirostat_tau": llm.MirostatTau,
"mirostat_eta": llm.MirostatEta,
"penalize_nl": llm.PenalizeNewline,
"seed": llm.Seed,
"stop": llm.Stop,
}
if predict.Format == "json" {
request["grammar"] = jsonGrammar
}
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %w", err)
}
req := C.CString(buffer.String())
defer C.free(unsafe.Pointer(req))
cmpCtx := C.llama_server_completion(req)
if cmpCtx.task_id < 0 {
defer C.free(unsafe.Pointer(cmpCtx.err))
return fmt.Errorf(C.GoString(cmpCtx.err))
}
for {
select {
case <-ctx.Done():
// This handles the request cancellation
return errWrap(C.llama_server_completion_cancel(cmpCtx.task_id))
default:
result := C.llama_server_completion_next_result(cmpCtx.task_id)
if result.result_json != nil {
defer C.free(unsafe.Pointer(result.result_json))
}
var p prediction
if err := json.Unmarshal([]byte(C.GoString(result.result_json)), &p); err != nil {
err2 := errWrap(C.llama_server_completion_cancel(cmpCtx.task_id))
return errors.Join(fmt.Errorf("error unmarshaling llm prediction response: %w", err), err2)
}
if p.Content != "" {
fn(PredictResult{
// Model: predict.Model, // XXX remove or replace?
CreatedAt: time.Now().UTC(),
Content: p.Content,
})
}
if p.Stop {
fn(PredictResult{
// Model: predict.Model, // XXX remove or replace?
CreatedAt: time.Now().UTC(),
TotalDuration: time.Since(predict.CheckpointStart),
Done: true,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}
}
}
func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
}
req := C.CString(string(data))
defer C.free(unsafe.Pointer(req))
var resp C.ext_server_resp
err = errWrap(C.llama_server_tokenize(req, &resp))
if resp.json_resp != nil {
defer C.free(unsafe.Pointer(resp.json_resp))
}
var encoded TokenizeResponse
if err2 := json.Unmarshal([]byte(C.GoString(resp.json_resp)), &encoded); err2 != nil {
return nil, fmt.Errorf("unmarshal encode response: %w", err2)
}
return encoded.Tokens, err
}
func (llm *llamaExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
if len(tokens) == 0 {
return "", nil
}
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
if err != nil {
return "", fmt.Errorf("marshaling decode data: %w", err)
}
req := C.CString(string(data))
defer C.free(unsafe.Pointer(req))
var resp C.ext_server_resp
err = errWrap(C.llama_server_detokenize(req, &resp))
if resp.json_resp != nil {
defer C.free(unsafe.Pointer(resp.json_resp))
}
var decoded DetokenizeResponse
if err2 := json.Unmarshal([]byte(C.GoString(resp.json_resp)), &decoded); err2 != nil {
return "", fmt.Errorf("unmarshal encode response: %w", err2)
}
return decoded.Content, err
}
func (llm *llamaExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
data, err := json.Marshal(TokenizeRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
req := C.CString(string(data))
defer C.free(unsafe.Pointer(req))
var resp C.ext_server_resp
err = errWrap(C.llama_server_embedding(req, &resp))
if resp.json_resp != nil {
defer C.free(unsafe.Pointer(resp.json_resp))
}
if err != nil {
return nil, err
}
var embedding EmbeddingResponse
if err := json.Unmarshal([]byte(C.GoString(resp.json_resp)), &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return embedding.Embedding, nil
}
func (llm *llamaExtServer) Ping(ctx context.Context) error {
// TODO - consider some mechanism to check if the main loop and llama.cpp are in a good state
return nil
}
func (llm *llamaExtServer) Close() {
C.llama_server_stop()
mutex.Unlock()
}

57
llm/gpu_cuda.go Normal file
View file

@ -0,0 +1,57 @@
//go:build linux || windows
package llm
import (
"errors"
"log"
"github.com/jmorganca/ollama/api"
)
/*
#cgo windows LDFLAGS: -L"/Program Files/NVIDIA Corporation/NVSMI/"
#cgo linux LDFLAGS: -lnvidia-ml
#include <stdlib.h>
#include "examples/server/server.h"
*/
import "C"
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
return int64(C.check_vram()), nil
}
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
if opts.NumGPU != -1 {
return opts.NumGPU
}
freeBytes, err := CheckVRAM()
if err != nil {
if !errors.Is(err, errNvidiaSMI) {
log.Print(err.Error())
}
// nvidia driver not installed or no nvidia GPU found
return 0
}
/*
Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
We can store the model weights and the kv cache in vram,
to enable kv chache vram storage add two additional layers to the number of layers retrieved from the model file.
*/
bytesPerLayer := fileSizeBytes / numLayer
// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
layers := int(freeBytes/bytesPerLayer) * 3 / 4
// TODO - not sure on this part... if we can't fit all the layers, just fallback to CPU
// if int64(layers) < numLayer {
// log.Printf("%d MB VRAM available, insufficient to load current model (reuires %d MB) - falling back to CPU %d", freeBytes/(1024*1024), fileSizeBytes/(1024*1024))
// return 0
// }
log.Printf("%d MB VRAM available, loading up to %d GPU layers out of %d", freeBytes/(1024*1024), layers, numLayer)
return layers
}

19
llm/gpu_darwin.go Normal file
View file

@ -0,0 +1,19 @@
//go:build darwin
package llm
import (
"github.com/jmorganca/ollama/api"
)
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
// TODO - assume metal, and return free memory?
return 0, errNvidiaSMI
}
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
// default to enable metal on macOS
return 1
}

View file

@ -0,0 +1,34 @@
# common logic accross linux and darwin
init_vars() {
PATCHES="0001-Expose-callable-API-for-server.patch"
CMAKE_DEFS="-DLLAMA_ACCELERATE=on"
# TODO - LLAMA_K_QUANTS is stale and needs to be mapped to newer cmake settings
CMAKE_TARGETS="--target ggml --target ggml_static --target llama --target build_info --target common --target ext_server"
if echo "${CGO_CFLAGS}" | grep -- '-g' > /dev/null ; then
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on ${CMAKE_DEFS}"
else
# TODO - add additional optimization flags...
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release ${CMAKE_DEFS}"
fi
}
git_module_setup() {
# TODO add flags to skip the init/patch logic to make it easier to mod llama.cpp code in-repo
git submodule init
git submodule update --force gguf
}
apply_patches() {
# Workaround git apply not handling creation well for iteration
rm -f gguf/examples/server/server.h
for patch in ${PATCHES} ; do
git -C gguf apply ../patches/${patch}
done
}
build() {
cmake -S gguf -B ${BUILD_DIR} ${CMAKE_DEFS}
cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
}

36
llm/llama.cpp/gen_darwin.sh Executable file
View file

@ -0,0 +1,36 @@
#!/bin/sh
# This script is intended to run inside the go generate
# working directory must be ../llm/llama.cpp
# TODO - add hardening to detect missing tools (cmake, etc.)
set -ex
set -o pipefail
echo "Starting darwin generate script"
source $(dirname $0)/gen_common.sh
init_vars
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 ${CMAKE_DEFS}"
case "${GOARCH}" in
"amd64")
CMAKE_DEFS="-DLLAMA_METAL=off -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 ${CMAKE_DEFS}"
BUILD_DIR="gguf/build/cpu"
;;
"arm64")
CMAKE_DEFS="-DLLAMA_METAL=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 ${CMAKE_DEFS}"
BUILD_DIR="gguf/build/metal"
;;
*)
echo "GOARCH must be set"
echo "this script is meant to be run from within go generate"
exit 1
;;
esac
git_module_setup
apply_patches
build
# Enable local debug/run usecase
if [ -e "gguf/ggml-metal.metal" ]; then
cp gguf/ggml-metal.metal ../../
fi

17
llm/llama.cpp/gen_linux.sh Executable file
View file

@ -0,0 +1,17 @@
#!/bin/sh
# This script is intended to run inside the go generate
# working directory must be ../llm/llama.cpp
set -ex
set -o pipefail
# TODO - stopped here - map the variables from above over and refine the case statement below
echo "Starting linux generate script"
source $(dirname $0)/gen_common.sh
init_vars
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="gguf/build/cuda"
git_module_setup
apply_patches
build

View file

@ -0,0 +1,51 @@
#!powershell
$ErrorActionPreference = "Stop"
function init_vars {
$script:buildDir="gguf/build/wincuda"
$script:installDir="gguf/build/wincuda/dist"
$script:patches = @("0001-Expose-callable-API-for-server.patch")
$script:cmakeDefs = @("-DLLAMA_NATIVE=off", "-DLLAMA_F16C=off", "-DLLAMA_FMA=off", "-DLLAMA_AVX512=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX=on", "-DLLAMA_K_QUANTS=on", "-DLLAMA_ACCELERATE=on", "-DLLAMA_CUBLAS=ON","-DCMAKE_VERBOSE_MAKEFILE=ON","-DBUILD_SHARED_LIBS=on","-A","x64")
if ($env:CGO_CFLAGS -contains "-g") {
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on")
$script:config += "RelWithDebInfo"
} else {
$script:config += "Release"
}
}
function git_module_setup {
# TODO add flags to skip the init/patch logic to make it easier to mod llama.cpp code in-repo
& git submodule init
& git submodule update --force gguf
}
function apply_patches {
rm -erroraction ignore -path "gguf/examples/server/server.h"
foreach ($patch in $patches) {
write-host "Applying patch $patch"
& git -C gguf apply ../patches/$patch
}
}
function build {
write-host "generating config with: cmake -S gguf -B $buildDir $cmakeDefs"
& cmake --version
& cmake -S gguf -B $buildDir $cmakeDefs
write-host "building with: cmake --build $buildDir --config $config"
& cmake --build $buildDir --config $config
}
function install {
rm -erroraction ignore -recurse -force -path $installDir
& cmake --install $buildDir --prefix $installDir --config $config
}
init_vars
git_module_setup
apply_patches
build
install

View file

@ -0,0 +1,3 @@
package llm
//go:generate sh ./gen_darwin.sh

View file

@ -1,9 +0,0 @@
package llm
//go:generate git submodule init
//go:generate git submodule update --force gguf
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_METAL=off -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_NAME=Darwin -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=on
//go:generate cmake --build gguf/build/cpu --target server --config Release
//go:generate mv gguf/build/cpu/bin/server gguf/build/cpu/bin/ollama-runner

View file

@ -1,9 +0,0 @@
package llm
//go:generate git submodule init
//go:generate git submodule update --force gguf
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
//go:generate cmake -S gguf -B gguf/build/metal -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
//go:generate cmake --build gguf/build/metal --target server --config Release
//go:generate mv gguf/build/metal/bin/server gguf/build/metal/bin/ollama-runner

View file

@ -1,14 +1,3 @@
package llm
//go:generate git submodule init
//go:generate git submodule update --force gguf
//go:generate git -C gguf apply ../patches/0001-copy-cuda-runtime-libraries.patch
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
//go:generate cmake --build gguf/build/cpu --target server --config Release
//go:generate mv gguf/build/cpu/bin/server gguf/build/cpu/bin/ollama-runner
//go:generate cmake -S gguf -B gguf/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off -DLLAMA_CUDA_PEER_MAX_BATCH_SIZE=0
//go:generate cmake --build gguf/build/cuda --target server --config Release
//go:generate mv gguf/build/cuda/bin/server gguf/build/cuda/bin/ollama-runner
//go:generate sh ./gen_linux.sh

View file

@ -1,17 +1,3 @@
package llm
//go:generate git submodule init
//go:generate git submodule update --force gguf
//go:generate git -C gguf apply ../patches/0001-update-default-log-target.patch
//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
//go:generate cmake --build gguf/build/cpu --target server --config Release
//go:generate cmd /c move gguf\build\cpu\bin\Release\server.exe gguf\build\cpu\bin\Release\ollama-runner.exe
//go:generate cmake -S ggml -B ggml/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on
//go:generate cmake --build ggml/build/cuda --target server --config Release
//go:generate cmd /c move ggml\build\cuda\bin\Release\server.exe ggml\build\cuda\bin\Release\ollama-runner.exe
//go:generate cmake -S gguf -B gguf/build/cuda -DLLAMA_CUBLAS=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off
//go:generate cmake --build gguf/build/cuda --target server --config Release
//go:generate cmd /c move gguf\build\cuda\bin\Release\server.exe gguf\build\cuda\bin\Release\ollama-runner.exe
//go:generate powershell -ExecutionPolicy Bypass -File ./gen_windows.ps1

View file

@ -0,0 +1,422 @@
From 64b3fbb150d12b3ca63ac2fb4e57bc46f41d2ccd Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Mon, 13 Nov 2023 12:25:58 -0800
Subject: [PATCH] Expose callable API for server
This adds an extern "C" interface within the example server
---
examples/server/CMakeLists.txt | 24 ++++
examples/server/server.cpp | 247 +++++++++++++++++++++++++++++++++
examples/server/server.h | 83 +++++++++++
ggml-cuda.cu | 1 +
4 files changed, 355 insertions(+)
create mode 100644 examples/server/server.h
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index 859cd12..4ea47a7 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -11,3 +11,27 @@ if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET ext_server)
+option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
+add_library(${TARGET} STATIC server.cpp)
+target_include_directories(${TARGET} PRIVATE ../../common)
+target_include_directories(${TARGET} PRIVATE ../..)
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
+target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
+
+if (BUILD_SHARED_LIBS)
+ set_target_properties(ext_server PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ target_compile_definitions(ext_server PRIVATE LLAMA_SHARED LLAMA_BUILD)
+ add_library(ext_server_shared SHARED $<TARGET_OBJECTS:ext_server>)
+ target_link_libraries(ext_server_shared PRIVATE ggml llama llava common ${CMAKE_THREAD_LIBS_INIT})
+ install(TARGETS ext_server_shared LIBRARY)
+endif()
+
+if (CUDAToolkit_FOUND)
+ target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+ if (WIN32)
+ target_link_libraries(ext_server_shared PRIVATE nvml)
+ endif()
+endif()
\ No newline at end of file
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 895f751..f939590 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -5,6 +5,9 @@
#include "../llava/clip.h"
#include "stb_image.h"
+#if defined(LLAMA_SERVER_LIBRARY)
+#include "server.h"
+#endif
#ifndef NDEBUG
// crash the server in debug mode, otherwise send an http 500 error
@@ -2631,6 +2634,7 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
}
}
+#ifndef LLAMA_SERVER_LIBRARY
int main(int argc, char **argv)
{
// own arguments required by this example
@@ -3065,3 +3069,246 @@ int main(int argc, char **argv)
llama_backend_free();
return 0;
}
+
+#else // LLAMA_SERVER_LIBRARY
+// Expose the llama server as a callable extern "C" API
+llama_server_context llama;
+std::atomic<bool> ext_server_running(false);
+std::thread ext_server_thread;
+inline ext_server_err makeErr(uint32_t code, std::string msg) {
+ if (code == 0) {
+ return ext_server_err{0, NULL};
+ }
+ const std::string::size_type size = msg.size();
+ ext_server_err ret = {
+ code,
+ new char[size + 1],
+ };
+ memcpy(ret.err, msg.c_str(), size + 1);
+ return ret;
+}
+
+ext_server_err llama_server_init(ext_server_params *sparams)
+{
+ log_set_target(stdout);
+ gpt_params params;
+ params.n_ctx = sparams->n_ctx;
+ params.n_batch = sparams->n_batch;
+ params.n_threads = sparams->n_threads;
+ params.n_parallel = sparams->n_parallel;
+ params.rope_freq_base = sparams->rope_freq_base;
+ params.rope_freq_scale = sparams->rope_freq_scale;
+
+ if (sparams->memory_f16) {
+ params.cache_type_k = "f16";
+ params.cache_type_v = "f16";
+ } else {
+ params.cache_type_k = "f32";
+ params.cache_type_v = "f32";
+ }
+
+ params.n_gpu_layers = sparams->n_gpu_layers;
+ params.main_gpu = sparams->main_gpu;
+ params.use_mlock = sparams->use_mlock;
+ params.use_mmap = sparams->use_mmap;
+ params.numa = sparams->numa;
+ params.embedding = sparams->embedding;
+ if (sparams->model != NULL) {
+ params.model = sparams->model;
+ }
+
+ for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL; la = la->next) {
+ params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
+ }
+
+ try {
+ llama_backend_init(params.numa);
+
+ // load the model
+ if (!llama.load_model(params))
+ {
+ // TODO - consider modifying the logging logic or patching load_model so we can capture more detailed error messages
+ // and pass them back to the caller for better UX
+ return makeErr(1, "error loading model " + params.model);
+ }
+
+ llama.initialize();
+ } catch (std::exception &e) {
+ return makeErr(1, e.what());
+ } catch (...) {
+ return makeErr(1, "Unknown Exception initializing llama server");
+ }
+ return makeErr(0, "");
+}
+
+void llama_server_start()
+{
+ // TODO mutex to protect thread creation
+ ext_server_thread = std::thread([&]()
+ {
+ ext_server_running = true;
+ try {
+ LOG_TEE("llama server main loop starting\n");
+ ggml_time_init();
+ while (ext_server_running.load())
+ {
+ if (!llama.update_slots()) {
+ LOG_TEE("unexpected error in llama server update_slots - exiting main loop\n");
+ break;
+ }
+ }
+ } catch (std::exception &e) {
+ LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
+ } catch (...) {
+ LOG_TEE("caught unknown exception in llama server main loop\n");
+ }
+ LOG_TEE("\nllama server shutting down\n");
+ llama_backend_free();
+ });
+}
+
+void llama_server_stop() {
+ // TODO - too verbose, remove once things are solid
+ LOG_TEE("requesting llama server shutdown\n");
+ ext_server_running = false;
+ ext_server_thread.join();
+ LOG_TEE("llama server shutdown complete\n");
+}
+
+ext_server_completion_resp llama_server_completion(const char *json_req) {
+ std::string msg;
+ ext_server_completion_resp resp = {
+ 0,
+ NULL,
+ };
+ try {
+ json data = json::parse(json_req);
+ resp.task_id = llama.request_completion(data, false, false, -1);
+ return resp;
+ } catch (std::exception &e) {
+ msg = e.what();
+ } catch (...) {
+ msg = "Unknown Exception during completion";
+ }
+ const std::string::size_type size = msg.size();
+ resp.task_id = 0;
+ resp.err = new char[size + 1];
+ memcpy(resp.err, msg.c_str(), size + 1);
+ return resp;
+}
+
+ext_task_result llama_server_completion_next_result(const int task_id) {
+ std::string msg;
+ ext_task_result resp = {-1,false,false,NULL};
+ try {
+ task_result result = llama.next_result(task_id);
+ std::string result_json = result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
+ const std::string::size_type size = result_json.size();
+ resp.id = result.id;
+ resp.stop = result.stop;
+ resp.error = result.error;
+ resp.result_json = new char[size + 1];
+ memcpy(resp.result_json, result_json.c_str(), size + 1);
+ if (result.error) {
+ llama.request_cancel(task_id);
+ } else if (result.stop) {
+ llama.request_cancel(task_id);
+ }
+ return resp;
+ } catch (std::exception &e) {
+ msg = e.what(); // TODO - json?
+ } catch (...) {
+ msg = "Unknown Exception during completion";
+ }
+ resp.error = true;
+ const std::string::size_type size = msg.size();
+ resp.result_json = new char[size + 1];
+ memcpy(resp.result_json, msg.c_str(), size + 1);
+ return resp;
+}
+
+ext_server_err llama_server_completion_cancel(const int task_id) {
+ try {
+ llama.request_cancel(task_id);
+ } catch (std::exception &e) {
+ return makeErr(1, e.what());
+ } catch (...) {
+ return makeErr(1, "Unknown Exception running llama server");
+ }
+ return makeErr(0, "");
+}
+
+
+ext_server_err llama_server_tokenize(const char *json_req, ext_server_resp *resp) {
+ resp->json_resp = NULL;
+ try {
+ const json body = json::parse(json_req);
+ std::vector<llama_token> tokens;
+ if (body.count("content") != 0)
+ {
+ tokens = llama.tokenize(body["content"], false);
+ }
+ const json data = format_tokenizer_response(tokens);
+ std::string result_json = data.dump();
+ const std::string::size_type size = result_json.size();
+ resp->json_resp = new char[size + 1];
+ memcpy(resp->json_resp, result_json.c_str(), size + 1);
+ } catch (std::exception &e) {
+ return makeErr(1, e.what());
+ } catch (...) {
+ return makeErr(1, "Unknown Exception during tokenize");
+ }
+ return makeErr(0, "");
+}
+
+ext_server_err llama_server_detokenize(const char *json_req, ext_server_resp *resp) {
+ resp->json_resp = NULL;
+ try {
+ const json body = json::parse(json_req);
+ std::string content;
+ if (body.count("tokens") != 0)
+ {
+ const std::vector<llama_token> tokens = body["tokens"];
+ content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend());
+ }
+ const json data = format_detokenized_response(content);
+ std::string result_json = data.dump();
+ const std::string::size_type size = result_json.size();
+ resp->json_resp = new char[size + 1];
+ memcpy(resp->json_resp, result_json.c_str(), size + 1);
+ } catch (std::exception &e) {
+ return makeErr(1, e.what());
+ } catch (...) {
+ return makeErr(1, "Unknown Exception during detokenize");
+ }
+ return makeErr(0, "");
+}
+
+ext_server_err llama_server_embedding(const char *json_req, ext_server_resp *resp) {
+ resp->json_resp = NULL;
+ try {
+ const json body = json::parse(json_req);
+ json prompt;
+ if (body.count("content") != 0)
+ {
+ prompt = body["content"];
+ }
+ else
+ {
+ prompt = "";
+ }
+ const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true, -1);
+ task_result result = llama.next_result(task_id);
+ std::string result_json = result.result_json.dump();
+ const std::string::size_type size = result_json.size();
+ resp->json_resp = new char[size + 1];
+ memcpy(resp->json_resp, result_json.c_str(), size + 1);
+ } catch (std::exception &e) {
+ return makeErr(1, e.what());
+ } catch (...) {
+ return makeErr(1, "Unknown Exception during detokenize");
+ }
+ return makeErr(0, "");
+}
+
+#endif // LLAMA_SERVER_LIBRARY
\ No newline at end of file
diff --git a/examples/server/server.h b/examples/server/server.h
new file mode 100644
index 0000000..4d03b1e
--- /dev/null
+++ b/examples/server/server.h
@@ -0,0 +1,83 @@
+#if defined(LLAMA_SERVER_LIBRARY)
+#ifndef LLAMA_SERVER_H
+#define LLAMA_SERVER_H
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdbool.h>
+
+// This exposes extern C entrypoints into the llama_server
+// To enable the server compile with LLAMA_SERVER_LIBRARY
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+ // TODO - clean the type def's up a bit for better consistency
+ typedef struct ext_server_err {
+ uint32_t code; // 0 on success, > 0 on error
+ char *err; // null if code == 0; else contains error message. Caller responsible for freeing memory
+ } ext_server_err;
+
+ typedef struct ext_server_lora_adapter {
+ char *adapter;
+ float scale;
+ struct ext_server_lora_adapter *next;
+ } ext_server_lora_adapter;
+ typedef struct ext_server_params
+ {
+ char *model;
+ uint32_t n_ctx; // text context, 0 = from model
+ uint32_t n_batch; // prompt processing maximum batch size
+ uint32_t n_threads; // number of threads to use for generation
+ int32_t n_parallel; // number of parallel sequences to decodewra
+ float rope_freq_base; // RoPE base frequency, 0 = from model
+ float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
+ bool memory_f16; // use f16 instead of f32 for memory kv
+ int32_t n_gpu_layers; // number of layers to store in VRAM (-1 - use default)
+ int32_t main_gpu; // the GPU that is used for scratch and small tensors
+ bool use_mlock; // force system to keep model in RAM
+ bool use_mmap; // use mmap if possible
+ bool numa; // attempt optimizations that help on some NUMA systems
+ bool embedding; // get only sentence embedding
+ ext_server_lora_adapter* lora_adapters;
+ } ext_server_params;
+
+ // Initialize the server once per process
+ ext_server_err llama_server_init(ext_server_params *sparams);
+
+ // Run the main loop
+ void llama_server_start();
+ // Stop the main loop
+ void llama_server_stop();
+
+ typedef struct ext_task_result
+ {
+ int id;
+ bool stop;
+ bool error;
+ char* result_json; // caller responsible to free this memory
+ } ext_task_result;
+
+ typedef struct ext_server_completion_resp {
+ int task_id; // < 0 on error, >= 0 on success
+ char *err; // null if task_id >= 0; else contains error message. Caller responsible for freeing memory
+ } ext_server_completion_resp;
+ ext_server_completion_resp llama_server_completion(const char *json_req);
+ ext_task_result llama_server_completion_next_result(const int task_id);
+ ext_server_err llama_server_completion_cancel(const int task_id);
+
+ // Caller responsible for freeing json_resp
+ typedef struct ext_server_resp {
+ char *json_resp; // Caller responsible for freeing string
+ } ext_server_resp;
+ ext_server_err llama_server_tokenize(const char *json_req, ext_server_resp *resp);
+ ext_server_err llama_server_detokenize(const char *json_req, ext_server_resp *resp);
+ ext_server_err llama_server_embedding(const char *json_req, ext_server_resp *resp);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+#endif // LLAMA_SERVER_LIBRARY
\ No newline at end of file
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 85f7a29..ce51364 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6410,6 +6410,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
CUDA_CHECK(cudaGetDevice(&id));
src_ptr = (char *) extra->data_device[id];
} else {
+ fprintf(stderr, "ggml_cuda_cpy_tensor_2d assert: backend: %d\n", src->backend);
GGML_ASSERT(false);
}
char * dst_ptr = (char *) dst;
--
2.39.3 (Apple Git-145)

View file

@ -1,27 +0,0 @@
From 5dd02993e8cc2ce309157736b95bb572f274a3fd Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Wed, 20 Sep 2023 14:19:52 -0700
Subject: [PATCH] copy cuda runtime libraries
---
CMakeLists.txt | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 824d9f2..dd24137 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -274,6 +274,10 @@ if (LLAMA_CUBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
+ configure_file(${CUDAToolkit_LIBRARY_DIR}/libcudart.so ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/libcudart.so.${CUDAToolkit_VERSION_MAJOR}.0 COPYONLY)
+ configure_file(${CUDAToolkit_LIBRARY_DIR}/libcublas.so ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/libcublas.so.${CUDAToolkit_VERSION_MAJOR} COPYONLY)
+ configure_file(${CUDAToolkit_LIBRARY_DIR}/libcublasLt.so ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/libcublasLt.so.${CUDAToolkit_VERSION_MAJOR} COPYONLY)
+
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics
--
2.42.0

View file

@ -1,25 +0,0 @@
From 6465fec6290f0a7f5d4d0fbe6bcf634e4810dde6 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Mon, 23 Oct 2023 10:39:34 -0700
Subject: [PATCH] default log stderr
---
common/log.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/common/log.h b/common/log.h
index b8953fd..25522cd 100644
--- a/common/log.h
+++ b/common/log.h
@@ -90,7 +90,7 @@
// }
//
#ifndef LOG_TARGET
- #define LOG_TARGET log_handler()
+ #define LOG_TARGET nullptr
#endif
#ifndef LOG_TEE_TARGET
--
2.42.0

View file

@ -1,25 +1,12 @@
package llm
import (
"bufio"
"bytes"
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"math/rand"
"net/http"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"time"
@ -55,107 +42,6 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
ws ::= ([ \t\n] ws)?
`
//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS
type ModelRunner struct {
Path string // path to the model runner executable
Accelerated bool
}
func chooseRunners(workDir string) []ModelRunner {
buildPath := path.Join("llama.cpp", "gguf", "build")
var runners []ModelRunner
// set the runners based on the OS
// IMPORTANT: the order of the runners in the array is the priority order
switch runtime.GOOS {
case "darwin":
if runtime.GOARCH == "arm64" {
runners = []ModelRunner{{Path: path.Join(buildPath, "metal", "bin", "ollama-runner")}}
} else {
runners = []ModelRunner{{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")}}
}
case "linux":
runners = []ModelRunner{
{Path: path.Join(buildPath, "cuda", "bin", "ollama-runner"), Accelerated: true},
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
}
case "windows":
// TODO: select windows GPU runner here when available
runners = []ModelRunner{
{Path: path.Join(buildPath, "cuda", "bin", "Release", "ollama-runner.exe"), Accelerated: true},
{Path: path.Join(buildPath, "cpu", "bin", "Release", "ollama-runner.exe")},
}
default:
log.Printf("unknown OS, running on CPU: %s", runtime.GOOS)
runners = []ModelRunner{
{Path: path.Join(buildPath, "cpu", "bin", "ollama-runner")},
}
}
runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail
for _, r := range runners {
// find all the files in the runner's bin directory
files, err := fs.Glob(llamaCppEmbed, path.Join(path.Dir(r.Path), "*"))
if err != nil {
// this is expected, ollama may be compiled without all runners packed in
log.Printf("%s runner not found: %v", r.Path, err)
continue
}
for _, f := range files {
runnerAvailable = true
srcFile, err := llamaCppEmbed.Open(f)
if err != nil {
log.Fatalf("read llama runner %s: %v", f, err)
}
defer srcFile.Close()
// create the directory in case it does not exist, filepath.Dir() converts the file path to the OS's format
destPath := filepath.Join(workDir, filepath.Dir(f))
if err := os.MkdirAll(destPath, 0o755); err != nil {
log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err)
}
// create the path to the destination file, filepath.Base() converts the file path to the OS's format
destFile := filepath.Join(destPath, filepath.Base(f))
_, err = os.Stat(destFile)
switch {
case errors.Is(err, os.ErrNotExist):
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
log.Fatalf("write llama runner %s: %v", f, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil {
log.Fatalf("copy llama runner %s: %v", f, err)
}
case err != nil:
log.Fatalf("stat llama runner %s: %v", f, err)
}
}
}
if !runnerAvailable {
log.Fatalf("gguf runner not found")
}
// return the runners to try in priority order
localRunnersByPriority := []ModelRunner{}
for _, r := range runners {
// clean the ModelRunner paths so that they match the OS we are running on
localRunnersByPriority = append(localRunnersByPriority, ModelRunner{
Path: filepath.Clean(path.Join(workDir, r.Path)),
Accelerated: r.Accelerated,
})
}
return localRunnersByPriority
}
type llamaModel struct {
hyperparameters llamaHyperparameters
}
@ -237,72 +123,6 @@ var (
errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
)
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (int64, error) {
cmd := exec.Command("nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits")
var stdout bytes.Buffer
cmd.Stdout = &stdout
err := cmd.Run()
if err != nil {
return 0, errNvidiaSMI
}
var freeMiB int64
scanner := bufio.NewScanner(&stdout)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, "[Insufficient Permissions]") {
return 0, fmt.Errorf("GPU support may not enabled, check you have installed GPU drivers and have the necessary permissions to run nvidia-smi")
}
vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
if err != nil {
return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
}
freeMiB += vram
}
freeBytes := freeMiB * 1024 * 1024
if freeBytes < 2*format.GigaByte {
log.Printf("less than 2 GB VRAM available")
return 0, errAvailableVRAM
}
return freeBytes, nil
}
func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
if opts.NumGPU != -1 {
return opts.NumGPU
}
if runtime.GOOS == "linux" || runtime.GOOS == "windows" {
freeBytes, err := CheckVRAM()
if err != nil {
if !errors.Is(err, errNvidiaSMI) {
log.Print(err.Error())
}
// nvidia driver not installed or no nvidia GPU found
return 0
}
/*
Calculate bytes per layer, this will roughly be the size of the model file divided by the number of layers.
We can store the model weights and the kv cache in vram,
to enable kv chache vram storage add two additional layers to the number of layers retrieved from the model file.
*/
bytesPerLayer := fileSizeBytes / numLayer
// 75% of the absolute max number of layers we can fit in available VRAM, off-loading too many layers to the GPU can cause OOM errors
layers := int(freeBytes/bytesPerLayer) * 3 / 4
log.Printf("%d MB VRAM available, loading up to %d GPU layers", freeBytes/(1024*1024), layers)
return layers
}
// default to enable metal on macOS
return 1
}
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
ErrCh chan error
@ -331,204 +151,6 @@ func (w *StatusWriter) Write(b []byte) (int, error) {
return os.Stderr.Write(b)
}
func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
fileInfo, err := os.Stat(model)
if err != nil {
return nil, err
}
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
numGPU := NumGPU(numLayers, fileInfo.Size(), opts)
params := []string{
"--model", model,
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
"--n-gpu-layers", fmt.Sprintf("%d", numGPU),
"--embedding",
"--parallel", "2",
}
if opts.MainGPU > 0 {
params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
}
if opts.RopeFrequencyBase > 0 {
params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
}
if opts.RopeFrequencyScale > 0 {
params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
}
if opts.NumGQA > 0 {
params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
}
if len(adapters) > 0 {
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
params = append(params, "--lora", adapters[0])
}
if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append(params, "--mmproj", projectors[0])
}
if opts.NumThread > 0 {
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
}
if !opts.F16KV {
params = append(params, "--memory-f32")
}
if opts.UseMLock {
params = append(params, "--mlock")
}
if !opts.UseMMap {
params = append(params, "--no-mmap")
}
if opts.UseNUMA {
params = append(params, "--numa")
}
var runnerErr error
// start the llama.cpp server with a retry in case the port is already in use
for _, runner := range runners {
if runner.Accelerated && numGPU == 0 {
log.Printf("skipping accelerated runner because num_gpu=0")
continue
}
if _, err := os.Stat(runner.Path); err != nil {
log.Printf("llama runner not found: %v", err)
continue
}
port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
params := append(params, "--port", strconv.Itoa(port))
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(
ctx,
runner.Path,
params...,
)
var libraryPaths []string
if libraryPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, libraryPath)
}
libraryPaths = append(libraryPaths, filepath.Dir(runner.Path))
cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", strings.Join(libraryPaths, ":")))
cmd.Stdout = os.Stderr
statusWriter := NewStatusWriter()
cmd.Stderr = statusWriter
llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel, exitCh: make(chan error)}}
log.Print("starting llama runner")
if err := llm.Cmd.Start(); err != nil {
log.Printf("error starting the external llama runner: %v", err)
continue
}
// monitor the llama runner process and signal when it exits
go func() {
err := llm.Cmd.Wait()
// default to printing the exit message of the command process, it will probably just say 'exit staus 1'
errMsg := err.Error()
// try to set a better error message if llama runner logs captured an error
if statusWriter.LastErrMsg != "" {
errMsg = statusWriter.LastErrMsg
}
log.Println(errMsg)
// llm.Cmd.Wait() can only be called once, use this exit channel to signal that the process has exited
llm.exitOnce.Do(func() {
close(llm.exitCh)
})
}()
if err := waitForServer(llm); err != nil {
log.Printf("error starting llama runner: %v", err)
llm.Close()
// default the runnerErr to the error returned by the most recent llama runner process
runnerErr = err
// capture the error directly from the runner process, if any
select {
case runnerErr = <-statusWriter.ErrCh:
default:
// the runner process probably timed out
}
// try again
continue
}
// server started successfully
return llm, nil
}
if runnerErr != nil {
// this is the error returned from the llama runner process that failed most recently
return nil, runnerErr
}
return nil, fmt.Errorf("failed to start a llama runner")
}
func waitForServer(llm *llama) error {
start := time.Now()
expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
log.Print("waiting for llama runner to start responding")
for {
select {
case <-llm.exitCh:
// failed to start subprocess
return fmt.Errorf("llama runner process has terminated")
case <-ticker.C:
if time.Now().After(expiresAt) {
// timeout
return fmt.Errorf("timed out waiting for llama runner to start")
}
if err := llm.Ping(context.Background()); err == nil {
// success
log.Printf("llama runner started in %f seconds", time.Since(start).Seconds())
return nil
}
}
}
}
func (llm *llama) Close() {
// signal the sub-process to terminate
llm.Cancel()
// wait for the command to exit to prevent race conditions with the next run
<-llm.exitCh
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
log.Printf("llama runner stopped with error: %v", llm.StatusWriter.LastErrMsg)
} else {
log.Print("llama runner stopped successfully")
}
}
func (llm *llama) SetOptions(opts api.Options) {
llm.Options = opts
}
type prediction struct {
Content string `json:"content"`
Model string `json:"model"`
@ -561,158 +183,6 @@ type PredictResult struct {
EvalDuration time.Duration
}
// IsRetryable checks if the line matches a condition that can be retried
func isRetryable(line []byte) bool {
return bytes.Contains(line, []byte("slot unavailable"))
}
func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
imageData := llm.ImageData
if len(predict.Images) > 0 {
for cnt, i := range predict.Images {
imageData = append(imageData, ImageData{Data: i, ID: cnt})
}
}
log.Printf("loaded %d images", len(imageData))
request := map[string]any{
"prompt": predict.Prompt,
"stream": true,
"n_predict": llm.NumPredict,
"n_keep": llm.NumKeep,
"main_gpu": llm.MainGPU,
"temperature": llm.Temperature,
"top_k": llm.TopK,
"top_p": llm.TopP,
"tfs_z": llm.TFSZ,
"typical_p": llm.TypicalP,
"repeat_last_n": llm.RepeatLastN,
"repeat_penalty": llm.RepeatPenalty,
"presence_penalty": llm.PresencePenalty,
"frequency_penalty": llm.FrequencyPenalty,
"mirostat": llm.Mirostat,
"mirostat_tau": llm.MirostatTau,
"mirostat_eta": llm.MirostatEta,
"penalize_nl": llm.PenalizeNewline,
"seed": llm.Seed,
"stop": llm.Stop,
"image_data": imageData,
}
if predict.Format == "json" {
request["grammar"] = jsonGrammar
}
retryDelay := 100 * time.Microsecond
for retries := 0; retries < maxRetries; retries++ {
if retries > 0 {
time.Sleep(retryDelay) // wait before retrying
retryDelay *= 2 // exponential backoff
}
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil {
return fmt.Errorf("error creating POST request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("POST predict: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed reading llm error response: %w", err)
}
log.Printf("llm predict error: %s", bodyBytes)
return fmt.Errorf("%s", bodyBytes)
}
scanner := bufio.NewScanner(resp.Body)
// increase the buffer size to avoid running out of space
buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize)
retryNeeded := false
for scanner.Scan() {
select {
case <-ctx.Done():
// This handles the request cancellation
return ctx.Err()
default:
line := scanner.Bytes()
if len(line) == 0 {
continue
}
if isRetryable(line) {
retryNeeded = true
break
}
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line)
}
var p prediction
if err := json.Unmarshal(evt, &p); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
}
if p.Content != "" {
fn(PredictResult{
Content: p.Content,
})
}
if p.Stop {
fn(PredictResult{
Done: true,
PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
})
return nil
}
}
}
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "unexpected EOF") {
// this means the llama runner subprocess crashed
llm.Close()
if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
}
return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
}
return fmt.Errorf("error reading llm response: %v", err)
}
if !retryNeeded {
return nil // success
}
}
// should never reach here ideally
return fmt.Errorf("max retries exceeded")
}
type TokenizeRequest struct {
Content string `json:"content"`
}
@ -721,43 +191,6 @@ type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
endpoint := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("encode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("do encode request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read encode request: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}
var encoded TokenizeResponse
if err := json.Unmarshal(body, &encoded); err != nil {
return nil, fmt.Errorf("unmarshal encode response: %w", err)
}
return encoded.Tokens, nil
}
type DetokenizeRequest struct {
Tokens []int `json:"tokens"`
}
@ -766,46 +199,6 @@ type DetokenizeResponse struct {
Content string `json:"content"`
}
func (llm *llama) Decode(ctx context.Context, tokens []int) (string, error) {
if len(tokens) == 0 {
return "", nil
}
endpoint := fmt.Sprintf("http://127.0.0.1:%d/detokenize", llm.Port)
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
if err != nil {
return "", fmt.Errorf("marshaling decode data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
if err != nil {
return "", fmt.Errorf("decode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("do decode request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read decode request: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm decode error: %s", body)
return "", fmt.Errorf("%s", body)
}
var decoded DetokenizeResponse
if err := json.Unmarshal(body, &decoded); err != nil {
return "", fmt.Errorf("unmarshal encode response: %w", err)
}
return decoded.Content, nil
}
type EmbeddingRequest struct {
Content string `json:"content"`
}
@ -813,52 +206,3 @@ type EmbeddingRequest struct {
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error) {
endpoint := fmt.Sprintf("http://127.0.0.1:%d/embedding", llm.Port)
data, err := json.Marshal(TokenizeRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("POST embedding: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading embed response: %w", err)
}
if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}
var embedding EmbeddingResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return embedding.Embedding, nil
}
// Ping checks that the server subprocess is still running and responding to requests
func (llm *llama) Ping(ctx context.Context) error {
resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
if err != nil {
return fmt.Errorf("ping resp: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected ping status: %s", resp.Status)
}
return nil
}

View file

@ -18,7 +18,6 @@ type LLM interface {
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)
SetOptions(api.Options)
Close()
Ping(context.Context) error
}
@ -79,5 +78,5 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
return newLlama(model, adapters, projectors, chooseRunners(workDir), ggml.NumLayers(), opts)
return newLlamaExtServer(model, adapters, projectors, ggml.NumLayers(), opts)
}

View file

@ -9,7 +9,7 @@ mkdir -p dist
for TARGETARCH in arm64 amd64; do
GOOS=darwin GOARCH=$TARGETARCH go generate ./...
GOOS=darwin GOARCH=$TARGETARCH go build -o dist/ollama-darwin-$TARGETARCH
CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -o dist/ollama-darwin-$TARGETARCH
rm -rf llm/llama.cpp/*/build
done

View file

@ -7,7 +7,7 @@ export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version
mkdir -p dist
for TARGETARCH in arm64 amd64; do
for TARGETARCH in amd64 arm64; do
docker buildx build --load --platform=linux/$TARGETARCH --build-arg=VERSION --build-arg=GOFLAGS -f Dockerfile.build -t builder:$TARGETARCH .
docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH
docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH

View file

@ -0,0 +1,35 @@
#!/bin/bash
# This script sets up integration tests which run the full stack to verify
# inference locally
set -e
set -o pipefail
REPO=$(dirname $0)/../
export OLLAMA_MODELS=${REPO}/test_data/models
REGISTRY_SCHEME=https
REGISTRY=registry.ollama.ai
TEST_MODEL=library/orca-mini
TEST_MODEL_TAG=latest
ACCEPT_HEADER="Accept: application/vnd.docker.distribution.manifest.v2+json"
mkdir -p ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/
mkdir -p ${OLLAMA_MODELS}/blobs/
echo "Pulling manifest for ${TEST_MODEL}:${TEST_MODEL_TAG}"
curl -s --header "${ACCEPT_HEADER}" \
-o ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} \
${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/manifests/${TEST_MODEL_TAG}
CFG_HASH=$(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".config.digest")
echo "Pulling config blob ${CFG_HASH}"
curl -L -C - --header "${ACCEPT_HEADER}" \
-o ${OLLAMA_MODELS}/blobs/${CFG_HASH} \
${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${CFG_HASH}
for LAYER in $(cat ${OLLAMA_MODELS}/manifests/${REGISTRY}/${TEST_MODEL}/${TEST_MODEL_TAG} | jq -r ".layers[].digest" ) ; do
echo "Pulling blob ${LAYER}"
curl -L -C - --header "${ACCEPT_HEADER}" \
-o ${OLLAMA_MODELS}/blobs/${LAYER} \
${REGISTRY_SCHEME}://${REGISTRY}/v2/${TEST_MODEL}/blobs/${LAYER}
done

103
server/llm_test.go Normal file
View file

@ -0,0 +1,103 @@
package server
import (
"context"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api"
)
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies
// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^
var (
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "tell me a short story about agi?",
Options: map[string]interface{}{},
}, {
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Options: map[string]interface{}{},
},
}
resp = [2]string{
"once upon a time",
"fourth thursday",
}
)
func TestIntegrationSimpleOrcaMini(t *testing.T) {
SkipIFNoTestData(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
defer cancel()
opts := api.DefaultOptions()
opts.Seed = 42
opts.Temperature = 0.0
model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
defer llmRunner.Close()
response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
assert.Contains(t, strings.ToLower(response), resp[0])
}
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
SkipIFNoTestData(t)
t.Skip("concurrent prediction on single runner not currently supported")
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
defer cancel()
opts := api.DefaultOptions()
opts.Seed = 42
opts.Temperature = 0.0
var wg sync.WaitGroup
wg.Add(len(req))
model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
defer llmRunner.Close()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
}(i)
}
wg.Wait()
}
func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
SkipIFNoTestData(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
defer cancel()
opts := api.DefaultOptions()
opts.Seed = 42
opts.Temperature = 0.0
var wg sync.WaitGroup
wg.Add(len(req))
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
defer llmRunner.Close()
response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
t.Logf("Prompt: %s\nResponse: %s", req[0].Prompt, response)
assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
}(i)
}
wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency

76
server/llm_utils_test.go Normal file
View file

@ -0,0 +1,76 @@
package server
import (
"context"
"errors"
"os"
"path"
"runtime"
"testing"
"time"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm"
"github.com/stretchr/testify/require"
)
func SkipIFNoTestData(t *testing.T) {
modelDir := getModelDir()
if _, err := os.Stat(modelDir); errors.Is(err, os.ErrNotExist) {
t.Skipf("%s does not exist - skipping integration tests", modelDir)
}
}
func getModelDir() string {
_, filename, _, _ := runtime.Caller(0)
return path.Dir(path.Dir(filename) + "/../test_data/models/.")
}
func PrepareModelForPrompts(t *testing.T, modelName string, opts api.Options) (*Model, llm.LLM) {
modelDir := getModelDir()
os.Setenv("OLLAMA_MODELS", modelDir)
model, err := GetModel(modelName)
require.NoError(t, err, "GetModel ")
err = opts.FromMap(model.Options)
require.NoError(t, err, "opts from model ")
runner, err := llm.New("unused", model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
require.NoError(t, err, "llm.New failed")
return model, runner
}
func OneShotPromptResponse(t *testing.T, ctx context.Context, req api.GenerateRequest, model *Model, runner llm.LLM) string {
checkpointStart := time.Now()
prompt, err := model.Prompt(PromptVars{
System: req.System,
Prompt: req.Prompt,
First: len(req.Context) == 0,
})
require.NoError(t, err, "prompt generation failed")
success := make(chan bool, 1)
response := ""
cb := func(r llm.PredictResult) {
if !r.Done {
response += r.Content
} else {
success <- true
}
}
checkpointLoaded := time.Now()
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
}
err = runner.Predict(ctx, predictReq, cb)
require.NoError(t, err, "predict call failed")
select {
case <-ctx.Done():
t.Errorf("failed to complete before timeout: \n%s", response)
return ""
case <-success:
return response
}
}

View file

@ -126,10 +126,6 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
loaded.Options = &opts
}
// update options for the loaded llm
// TODO(mxyng): this isn't thread safe, but it should be fine for now
loaded.runner.SetOptions(opts)
loaded.expireAt = time.Now().Add(sessionDuration)
if loaded.expireTimer == nil {