Request and model concurrency

This change adds support for multiple concurrent requests, as well as
loading multiple models by spawning multiple runners. The default
settings are currently set at 1 concurrent request per model and only 1
loaded model at a time, but these can be adjusted by setting
OLLAMA_NUM_PARALLEL and OLLAMA_MAX_LOADED_MODELS.
This commit is contained in:
Daniel Hiltgen 2024-03-30 09:50:05 -07:00
parent ee448deaba
commit 34b9db5afc
30 changed files with 2572 additions and 1387 deletions

View file

@ -91,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) {
}, nil
}
func NewClient(base *url.URL, http *http.Client) *Client {
return &Client{
base: base,
http: http,
}
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader
var data []byte

View file

@ -15,6 +15,7 @@ const (
KibiByte = Byte * 1024
MebiByte = KibiByte * 1024
GibiByte = MebiByte * 1024
)
func HumanBytes(b int64) string {

View file

@ -7,7 +7,7 @@ import (
"log/slog"
"os"
"path/filepath"
"strconv"
"runtime"
"strings"
)
@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) {
return ret, nil
}
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
// Set the visible devices if not already set
// TODO - does sort order matter?
devices := []string{}
for i := range ids {
if _, skipped := skip[i]; skipped {
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "rocm" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
continue
}
devices = append(devices, strconv.Itoa(i))
ids = append(ids, info.ID)
}
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
}
func commonAMDValidateLibDir() (string, error) {
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
val := strings.Join(devices, ",")
err := os.Setenv("HIP_VISIBLE_DEVICES", val)
if err != nil {
slog.Warn(fmt.Sprintf("failed to set env: %s", err))
} else {
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
// Scan the LD_LIBRARY_PATH or PATH
pathEnv := "LD_LIBRARY_PATH"
if runtime.GOOS == "windows" {
pathEnv = "PATH"
}
paths := os.Getenv(pathEnv)
for _, path := range filepath.SplitList(paths) {
d, err := filepath.Abs(path)
if err != nil {
continue
}
if rocmLibUsable(d) {
return d, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View file

@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
func (hl *HipLib) Release() {
err := windows.FreeLibrary(hl.dll)
if err != nil {
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
slog.Warn("failed to unload amdhip64.dll", "error", err)
}
hl.dll = 0
}
@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
return 0
}
if status != hipSuccess {
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
}
return count
}

View file

@ -11,6 +11,8 @@ import (
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
// Discovery logic for AMD/ROCm GPUs
@ -24,9 +26,6 @@ const (
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
RocmStandardLocation = "/opt/rocm/lib"
// TODO find a better way to detect iGPU instead of minimum memory
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
)
var (
@ -35,14 +34,11 @@ var (
)
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
// and the user hasn't already set this variable
func AMDGetGPUInfo(resp *GpuInfo) {
// TODO - DRY this out with windows
func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
if !AMDDetected() {
return
return resp
}
skip := map[int]interface{}{}
// Opportunistic logging of driver version to aid in troubleshooting
ver, err := AMDDriverVersion()
@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
slog.Info("AMD Driver: " + ver)
} else {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
}
// If the user has specified exactly which GPUs to use, look up their memory
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
if visibleDevices != "" {
ids := []int{}
for _, idStr := range strings.Split(visibleDevices, ",") {
id, err := strconv.Atoi(idStr)
if err != nil {
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
} else {
ids = append(ids, id)
}
}
amdProcMemLookup(resp, nil, ids)
return
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
var visibleDevices []string
hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
switch {
// TODO is this priorty order right?
case hipVD != "":
visibleDevices = strings.Split(hipVD, ",")
case rocrVD != "":
visibleDevices = strings.Split(rocrVD, ",")
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
// all our test systems show GPU-XX indicating UUID is not supported
case gpuDO != "":
visibleDevices = strings.Split(gpuDO, ",")
}
// Gather GFX version information from all detected cards
gfx := AMDGFXVersions()
verStrings := []string{}
for i, v := range gfx {
verStrings = append(verStrings, v.ToGFXString())
if v.Major == 0 {
// Silently skip CPUs
skip[i] = struct{}{}
continue
}
if v.Major < 9 {
// TODO consider this a build-time setting if we can support 8xx family GPUs
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
skip[i] = struct{}{}
}
}
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
// Abort if all GPUs are skipped
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
}
updateLibPath(libDir)
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" {
supported, err := GetSupportedGFX(libDir)
var supported []string
libDir := ""
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
cpuCount := 0
for _, match := range matches {
slog.Debug("evaluating amdgpu node " + match)
fp, err := os.Open(match)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
}
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
for i, v := range gfx {
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
}
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
ids := make([]int, len(gfx))
i := 0
for k := range gfx {
ids[i] = k
i++
}
amdProcMemLookup(resp, skip, ids)
if resp.memInfo.DeviceCount == 0 {
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
}
func updateLibPath(libDir string) {
ldPaths := []string{}
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
ldPaths = strings.Split(val, ":")
}
for _, d := range ldPaths {
if d == libDir {
return
}
}
val := strings.Join(append(ldPaths, libDir), ":")
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
os.Setenv("LD_LIBRARY_PATH", val)
}
// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
slog.Debug("discovering VRAM for amdgpu devices")
if len(ids) == 0 {
entries, err := os.ReadDir(AMDNodesSysfsDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
return
}
for _, node := range entries {
if !node.IsDir() {
continue
}
id, err := strconv.Atoi(node.Name())
if err != nil {
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
continue
}
ids = append(ids, id)
}
}
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
for _, id := range ids {
if _, skipped := skip[id]; skipped {
slog.Debug("failed to open sysfs node", "file", match, "error", err)
continue
}
defer fp.Close()
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug("failed to parse node ID", "error", err)
continue
}
scanner := bufio.NewScanner(fp)
isCPU := false
var major, minor, patch uint64
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
// Detect CPUs
if len(ver) == 2 && ver[1] == "0" {
slog.Debug("detected CPU " + match)
isCPU = true
break
}
if len(ver) != 2 || len(ver[1]) < 5 {
slog.Warn("malformed "+match, "gfx_target_version", line)
// If this winds up being a CPU, our offsets may be wrong
continue
}
l := len(ver[1])
var err1, err2, err3 error
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
}
// TODO - any other properties we want to extract and record?
// vendor_id + device_id -> pci lookup for "Name"
// Other metrics that may help us understand relative performance between multiple GPUs
}
if isCPU {
cpuCount++
continue
}
// CPUs are always first in the list
gpuID := nodeID - cpuCount
// Shouldn't happen, but just in case...
if gpuID < 0 {
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
return []GpuInfo{}
}
if int(major) < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%d", major, minor, patch), "gpu", gpuID)
continue
}
// Look up the memory for the current node
totalMemory := uint64(0)
usedMemory := uint64(0)
// Adjust for sysfs vs HIP ids
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
propFiles, err := filepath.Glob(propGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
}
// 1 or more memory banks - sum the values of all of them
for _, propFile := range propFiles {
fp, err := os.Open(propFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
continue
}
defer fp.Close()
@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
}
}
if totalMemory == 0 {
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
skip[id] = struct{}{}
slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
continue
}
if totalMemory < IGPUMemLimit {
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
skip[id] = struct{}{}
continue
}
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
usedFiles, err := filepath.Glob(usedGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
continue
}
for _, usedFile := range usedFiles {
fp, err := os.Open(usedFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
continue
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
continue
}
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
if err != nil {
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
slog.Warn("malformed used memory", "data", string(data), "error", err)
continue
}
usedMemory += used
}
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
if totalMemory < IGPUMemLimit {
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
continue
}
slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: (totalMemory - usedMemory),
},
ID: fmt.Sprintf("%d", gpuID),
// Name: not exposed in sysfs directly, would require pci device id lookup
Major: int(major),
Minor: int(minor),
Patch: int(patch),
MinimumMemory: rocmMinimumMemory,
}
// If the user wants to filter to a subset of devices, filter out if we aren't a match
if len(visibleDevices) > 0 {
include := false
for _, visible := range visibleDevices {
if visible == gpuInfo.ID {
include = true
break
}
}
if !include {
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
continue
}
}
// Final validation is gfx compatibility - load the library if we haven't already loaded it
// even if the user overrides, we still need to validate the library
if libDir == "" {
libDir, err = AMDValidateLibDir()
if err != nil {
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return []GpuInfo{}
}
}
gpuInfo.DependencyPath = libDir
if gfxOverride == "" {
// Only load supported list once
if len(supported) == 0 {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return []GpuInfo{}
}
slog.Debug("rocm supported GPUs", "types", supported)
}
gfx := fmt.Sprintf("gfx%d%d%d", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
continue
} else {
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
// The GPU has passed all the verification steps and is supported
resp = append(resp, gpuInfo)
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
if len(resp) == 0 {
slog.Info("no compatible amdgpu devices detected")
}
return resp
}
// Quick check for AMD driver so we can skip amdgpu discovery if not present
@ -280,87 +297,24 @@ func AMDDetected() bool {
slog.Debug("amdgpu driver not detected " + sysfsDir)
return false
} else if err != nil {
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
return false
}
return true
}
func setupLink(source, target string) error {
if err := os.RemoveAll(target); err != nil {
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
}
if err := os.Symlink(source, target); err != nil {
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
}
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
return nil
}
// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own
func AMDValidateLibDir() (string, error) {
// We rely on the rpath compiled into our library to find rocm
// so we establish a symlink to wherever we find it on the system
// to <payloads>/rocm
payloadsDir, err := PayloadsDir()
if err != nil {
return "", err
}
// If we already have a rocm dependency wired, nothing more to do
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
if rocmLibUsable(rocmTargetDir) {
return rocmTargetDir, nil
}
// next to the running binary
exe, err := os.Executable()
libDir, err := commonAMDValidateLibDir()
if err == nil {
peerDir := filepath.Dir(exe)
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
return libDir, nil
}
// Well known ollama installer path
installedRocmDir := "/usr/share/ollama/lib/rocm"
if rocmLibUsable(installedRocmDir) {
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "lib")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
}
}
// Scan the library path for potential matches
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
if rocmLibUsable(d) {
return rocmTargetDir, setupLink(d, rocmTargetDir)
}
}
// Well known location(s)
if rocmLibUsable("/opt/rocm/lib") {
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
return installedRocmDir, nil
}
// If we still haven't found a usable rocm, the user will have to install it on their own
@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
}
return strings.TrimSpace(string(verString)), nil
}
func AMDGFXVersions() map[int]Version {
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
res := map[int]Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
continue
}
if i == 0 {
// Skipping the CPU
continue
}
// Align with HIP IDs (zero is first GPU, not CPU)
i -= 1
scanner := bufio.NewScanner(fp)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
if ver[1] != "0" {
slog.Debug("malformed " + line)
}
res[i] = Version{
Major: 0,
Minor: 0,
Patch: 0,
}
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res[i] = Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
}
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

View file

@ -7,7 +7,10 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
const (
@ -22,36 +25,32 @@ var (
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
)
func AMDGetGPUInfo(resp *GpuInfo) {
func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
hl, err := NewHipLib()
if err != nil {
slog.Debug(err.Error())
return
return nil
}
defer hl.Release()
skip := map[int]interface{}{}
ids := []int{}
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
ver, err := hl.AMDDriverVersion()
if err == nil {
slog.Info("AMD Driver: " + ver)
} else {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
slog.Debug("error looking up amd driver version", "error", err)
}
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
if count == 0 {
return
return nil
}
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return nil
}
var supported []string
@ -59,95 +58,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
if gfxOverride == "" {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return nil
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
slog.Info(fmt.Sprintf("detected %d hip devices", count))
slog.Info("detected hip devices", "count", count)
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
for i := 0; i < count; i++ {
ids = append(ids, i)
err = hl.HipSetDevice(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
slog.Warn("set device", "id", i, "error", err)
continue
}
props, err := hl.HipGetDeviceProperties(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
slog.Warn("get properties", "id", i, "error", err)
continue
}
n := bytes.IndexByte(props.Name[:], 0)
name := string(props.Name[:n])
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
// TODO is UUID actually populated on windows?
// Can luid be used on windows for setting visible devices (and is it actually set?)
n = bytes.IndexByte(props.GcnArchName[:], 0)
gfx := string(props.GcnArchName[:n])
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
var major, minor, patch string
switch len(gfx) {
case 6:
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
case 7:
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
}
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
// TODO Why isn't props.iGPU accurate!?
if strings.EqualFold(name, iGPUName) {
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
skip[i] = struct{}{}
slog.Info("iGPU detected skipping", "id", i)
continue
}
if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
continue
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
}
}
totalMemory, freeMemory, err := hl.HipMemGetInfo()
freeMemory, totalMemory, err := hl.HipMemGetInfo()
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
slog.Warn("get mem info", "id", i, "error", err)
continue
}
// TODO according to docs, freeMem may lie on windows!
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += freeMemory
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
if totalMemory < IGPUMemLimit {
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
continue
}
// TODO revisit this once ROCm v6 is available on windows.
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: freeMemory,
},
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
DependencyPath: libDir,
MinimumMemory: rocmMinimumMemory,
}
if major != "" {
gpuInfo.Major, err = strconv.Atoi(major)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if minor != "" {
gpuInfo.Minor, err = strconv.Atoi(minor)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if patch != "" {
gpuInfo.Patch, err = strconv.Atoi(patch)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if gpuInfo.Major < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%d", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
continue
}
resp = append(resp, gpuInfo)
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
}
// Abort if all GPUs are skipped
if len(skip) >= count {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
UpdatePath(libDir)
return resp
}
func AMDValidateLibDir() (string, error) {
// On windows non-admins typically can't create links
// so instead of trying to rely on rpath and a link in
// $LibDir/rocm, we instead rely on setting PATH to point
// to the location of the ROCm library
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
libDir, err := commonAMDValidateLibDir()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
return libDir, nil
}
// Installer payload (if we're running from some other location)
@ -159,21 +180,6 @@ func AMDValidateLibDir() (string, error) {
return rocmTargetDir, nil
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

View file

@ -80,7 +80,7 @@ func cleanupTmpDirs() {
}
err = os.RemoveAll(d)
if err != nil {
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
}
}
}
@ -120,7 +120,7 @@ func UpdatePath(dir string) {
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
slog.Info("updating", "PATH", newPath)
os.Setenv("PATH", newPath)
}
// linux and darwin rely on rpath

22
gpu/cuda_common.go Normal file
View file

@ -0,0 +1,22 @@
//go:build linux || windows
package gpu
import (
"log/slog"
"strings"
)
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "cuda" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View file

@ -16,7 +16,6 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
@ -25,8 +24,8 @@ import (
)
type handles struct {
nvml *C.nvml_handle_t
cudart *C.cudart_handle_t
deviceCount int
cudart *C.cudart_handle_t
}
const (
@ -39,26 +38,10 @@ var gpuMutex sync.Mutex
// With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library
var NvmlLinuxGlobs = []string{
"/usr/local/cuda/lib64/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
"/usr/lib/wsl/lib/libnvidia-ml.so*",
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
"/opt/cuda/lib64/libnvidia-ml.so*",
"/usr/lib*/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
"/usr/local/lib*/libnvidia-ml.so*",
var RocmComputeMin = 9
// TODO: are these stubs ever valid?
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
}
var NvmlWindowsGlobs = []string{
"c:\\Windows\\System32\\nvml.dll",
}
// TODO find a better way to detect iGPU instead of minimum memory
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
var CudartLinuxGlobs = []string{
"/usr/local/cuda/lib64/libcudart.so*",
@ -88,26 +71,18 @@ func initGPUHandles() *handles {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles := &handles{nil, nil}
var nvmlMgmtName string
var nvmlMgmtPatterns []string
gpuHandles := &handles{}
var cudartMgmtName string
var cudartMgmtPatterns []string
tmpDir, _ := PayloadsDir()
switch runtime.GOOS {
case "windows":
nvmlMgmtName = "nvml.dll"
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
cudartMgmtName = "cudart64_*.dll"
localAppData := os.Getenv("LOCALAPPDATA")
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
case "linux":
nvmlMgmtName = "libnvidia-ml.so"
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
cudartMgmtName = "libcudart.so*"
if tmpDir != "" {
// TODO - add "payloads" for subprocess
@ -118,31 +93,21 @@ func initGPUHandles() *handles {
return gpuHandles
}
slog.Info("Detecting GPU type")
slog.Info("Detecting GPUs")
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
if len(cudartLibPaths) > 0 {
cudart := LoadCUDARTMgmt(cudartLibPaths)
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil {
slog.Info("Nvidia GPU detected via cudart")
slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
gpuHandles.cudart = cudart
return gpuHandles
}
}
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
if len(nvmlLibPaths) > 0 {
nvml := LoadNVMLMgmt(nvmlLibPaths)
if nvml != nil {
slog.Info("Nvidia GPU detected via nvidia-ml")
gpuHandles.nvml = nvml
gpuHandles.deviceCount = deviceCount
return gpuHandles
}
}
return gpuHandles
}
func GetGPUInfo() GpuInfo {
func GetGPUInfo() GpuInfoList {
// TODO - consider exploring lspci (and equivalent on windows) to check for
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
gpuMutex.Lock()
@ -150,9 +115,6 @@ func GetGPUInfo() GpuInfo {
gpuHandles := initGPUHandles()
defer func() {
if gpuHandles.nvml != nil {
C.nvml_release(*gpuHandles.nvml)
}
if gpuHandles.cudart != nil {
C.cudart_release(*gpuHandles.cudart)
}
@ -165,72 +127,63 @@ func GetGPUInfo() GpuInfo {
}
var memInfo C.mem_info_t
resp := GpuInfo{}
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
resp := []GpuInfo{}
// NVIDIA first
for i := 0; i < gpuHandles.deviceCount; i++ {
// TODO once we support CPU compilation variants of GPU libraries refine this...
if cpuVariant == "" && runtime.GOARCH == "amd64" {
continue
}
gpuInfo := GpuInfo{
Library: "cuda",
}
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
// Verify minimum compute capability
var cc C.nvml_compute_capability_t
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
continue
}
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
// Verify minimum compute capability
var cc C.cudart_compute_capability_t
C.cudart_compute_capability(*gpuHandles.cudart, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
continue
}
} else {
AMDGetGPUInfo(&resp)
if resp.Library != "" {
resp.MinimumMemory = rocmMinimumMemory
return resp
}
}
if resp.Library == "" {
C.cpu_check_ram(&memInfo)
resp.Library = "cpu"
resp.Variant = cpuVariant
}
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
return resp
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Major = int(memInfo.major)
gpuInfo.Minor = int(memInfo.minor)
gpuInfo.MinimumMemory = cudaMinimumMemory
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
resp = append(resp, gpuInfo)
}
// Then AMD
resp = append(resp, AMDGetGPUInfo()...)
if len(resp) == 0 {
C.cpu_check_ram(&memInfo)
if memInfo.err != nil {
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
return resp
}
gpuInfo := GpuInfo{
Library: "cpu",
Variant: cpuVariant,
}
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
resp = append(resp, gpuInfo)
}
resp.DeviceCount = uint32(memInfo.count)
resp.FreeMemory = uint64(memInfo.free)
resp.TotalMemory = uint64(memInfo.total)
return resp
}
func getCPUMem() (memInfo, error) {
func GetCPUMem() (memInfo, error) {
var ret memInfo
var info C.mem_info_t
C.cpu_check_ram(&info)
@ -243,29 +196,11 @@ func getCPUMem() (memInfo, error) {
return ret, nil
}
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return gpuInfo.FreeMemory, nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
}
func FindGPULibs(baseLibName string, patterns []string) []string {
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
var ldPaths []string
gpuLibPaths := []string{}
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
slog.Debug("Searching for GPU library", "name", baseLibName)
switch runtime.GOOS {
case "windows":
@ -283,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
}
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
}
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
slog.Debug("gpu library search", "globs", patterns)
for _, pattern := range patterns {
// Ignore glob discovery errors
matches, _ := filepath.Glob(pattern)
@ -311,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
}
}
}
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
return gpuLibPaths
}
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
var resp C.nvml_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range nvmlLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.nvml_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
}
}
return nil
}
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
var resp C.cudart_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range cudartLibPaths {
@ -340,13 +258,13 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
defer C.free(unsafe.Pointer(lib))
C.cudart_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
return int(resp.num_devices), &resp.ch, libPath
}
}
return nil
return 0, nil, ""
}
func getVerboseState() C.uint16_t {
@ -355,3 +273,22 @@ func getVerboseState() C.uint16_t {
}
return C.uint16_t(0)
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variable
//
// If different libraries are detected, the first one is what we use
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
if len(l) == 0 {
return "", ""
}
switch l[0].Library {
case "cuda":
return cudaGetVisibleDevicesEnv(l)
case "rocm":
return rocmGetVisibleDevicesEnv(l)
default:
slog.Debug("no filter required for library " + l[0].Library)
return "", ""
}
}

View file

@ -9,52 +9,41 @@ package gpu
*/
import "C"
import (
"fmt"
"log/slog"
"os"
"runtime"
"strconv"
)
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
func GetGPUInfo() GpuInfoList {
mem, _ := GetCPUMem()
if runtime.GOARCH == "amd64" {
// gpu not supported, this may not be metal
return 0, nil
}
return uint64(C.getRecommendedMaxVRAM()), nil
}
func GetGPUInfo() GpuInfo {
mem, _ := getCPUMem()
if runtime.GOARCH == "amd64" {
return GpuInfo{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
return []GpuInfo{
{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
},
}
}
return GpuInfo{
info := GpuInfo{
Library: "metal",
memInfo: mem,
ID: "0",
}
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
info.FreeMemory = info.TotalMemory
info.MinimumMemory = 0
return []GpuInfo{info}
}
func getCPUMem() (memInfo, error) {
func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: 0,
DeviceCount: 1,
}, nil
}
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
// No-op on darwin
return "", ""
}

View file

@ -38,12 +38,17 @@
extern "C" {
#endif
#define GPU_ID_LEN 64
typedef struct mem_info {
char *err; // If non-nill, caller responsible for freeing
char gpu_id[GPU_ID_LEN];
uint64_t total;
uint64_t free;
unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
char *err; // If non-nill, caller responsible for freeing
// Compute Capability
int major;
int minor;
} mem_info_t;
void cpu_check_ram(mem_info_t *resp);
@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp);
}
#endif
#include "gpu_info_nvml.h"
#include "gpu_info_cudart.h"
#endif // __GPU_INFO_H__

View file

@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
MEMORYSTATUSEX info;
info.dwLength = sizeof(info);
if (GlobalMemoryStatusEx(&info) != 0) {
resp->count = 1;
resp->total = info.ullTotalPhys;
resp->free = info.ullAvailPhys;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
} else {
resp->err = LOAD_ERR();
}
@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
if (sysinfo(&info) != 0) {
resp->err = strdup(strerror(errno));
} else {
resp->count = 1;
resp->total = info.totalram * info.mem_unit;
resp->free = info.freeram * info.mem_unit;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
}
return;
}

View file

@ -6,6 +6,7 @@
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
cudartReturn_t ret;
resp->err = NULL;
resp->num_devices = 0;
const int buflen = 256;
char buf[buflen + 1];
int i;
@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
{NULL, NULL},
};
@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
char *msg = LOAD_ERR();
@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
snprintf(buf, buflen, "cudart init failure: %d", ret);
@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
}
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
if (ret != CUDART_SUCCESS) {
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
}
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
resp->err = NULL;
cudartMemory_t memInfo = {0,0,0};
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle isn't initialized");
return;
}
// cudaGetDeviceCount takes int type, resp-> count is uint
int deviceCount;
ret = (*h.cudaGetDeviceCount)(&deviceCount);
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
cudaDeviceProp_t props;
ret = (*h.cudaGetDeviceProperties)(&props, i);
if (ret != CUDART_SUCCESS) {
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
resp->major = 0;
resp->minor = 0;
} else {
resp->count = (unsigned int)deviceCount;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp-> count; i++) {
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
int allNull = 1;
for (int j = 0; j < 16; j++) {
if (props.uuid.bytes[j] != 0) {
allNull = 0;
break;
}
}
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
resp->err = strdup(buf);
return;
if (allNull != 0) {
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
} else {
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
props.uuid.bytes[0],
props.uuid.bytes[1],
props.uuid.bytes[2],
props.uuid.bytes[3],
props.uuid.bytes[4],
props.uuid.bytes[5],
props.uuid.bytes[6],
props.uuid.bytes[7],
props.uuid.bytes[8],
props.uuid.bytes[9],
props.uuid.bytes[10],
props.uuid.bytes[11],
props.uuid.bytes[12],
props.uuid.bytes[13],
props.uuid.bytes[14],
props.uuid.bytes[15]
);
}
resp->major = props.major;
resp->minor = props.minor;
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
// TODO add other useful properties from props
}
}
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
int major = 0;
int minor = 0;
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle not initialized");
return;
}
int devices;
ret = (*h.cudaGetDeviceCount)(&devices);
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
resp->total = memInfo.total;
resp->free = memInfo.free;
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
}
void cudart_release(cudart_handle_t h) {

View file

@ -6,7 +6,8 @@
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum cudartReturn_enum {
CUDART_SUCCESS = 0,
CUDART_UNSUPPORTED = 1,
CUDA_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_MEMORY_ALLOCATION = 2,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now...
} cudartReturn_t;
@ -14,6 +15,11 @@ typedef enum cudartReturn_enum {
typedef enum cudartDeviceAttr_enum {
cudartDevAttrComputeCapabilityMajor = 75,
cudartDevAttrComputeCapabilityMinor = 76,
// TODO - not yet wired up but may be useful for Jetson or other
// integrated GPU scenarios with shared memory
cudaDevAttrIntegrated = 18
} cudartDeviceAttr_t;
typedef void *cudartDevice_t; // Opaque is sufficient
@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
int minor;
} cudartDriverVersion_t;
typedef struct cudaUUID {
unsigned char bytes[16];
} cudaUUID_t;
typedef struct cudaDeviceProp {
char name[256]; /**< ASCII string identifying device */
cudaUUID_t uuid; /**< 16-byte unique identifier */
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
size_t totalGlobalMem; /**< Global memory available on device in bytes */
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
int regsPerBlock; /**< 32-bit registers available per block */
int warpSize; /**< Warp size in threads */
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
int maxThreadsPerBlock; /**< Maximum number of threads per block */
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
int clockRate; /**< Clock frequency in kilohertz */
size_t totalConstMem; /**< Constant memory available on device in bytes */
int major; /**< Major compute capability */
int minor; /**< Minor compute capability */
size_t textureAlignment; /**< Alignment requirement for textures */
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
int multiProcessorCount; /**< Number of multiprocessors on device */
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
int integrated; /**< Device is integrated as opposed to discrete */
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
int maxTexture1D; /**< Maximum 1D texture size */
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
int maxSurface1D; /**< Maximum 1D surface size */
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
int ECCEnabled; /**< Device has ECC support enabled */
int pciBusID; /**< PCI bus ID of the device */
int pciDeviceID; /**< PCI device ID of the device */
int pciDomainID; /**< PCI domain ID of the device */
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
int asyncEngineCount; /**< Number of asynchronous engines */
int unifiedAddressing; /**< Device shares a unified address space with the host */
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
int memoryBusWidth; /**< Global memory bus width in bits */
int l2CacheSize; /**< Size of L2 cache in bytes */
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
int streamPrioritiesSupported; /**< Device supports stream priorities */
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
int localL1CacheSupported; /**< Device supports caching locals in L1 */
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
int managedMemory; /**< Device supports allocating managed memory on this system */
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
int computePreemptionSupported; /**< Device supports Compute Preemption */
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
} cudaDeviceProp_t;
typedef struct cudart_handle {
void *handle;
uint16_t verbose;
@ -38,23 +130,17 @@ typedef struct cudart_handle {
cudartReturn_t (*cudaGetDeviceCount)(int *);
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
} cudart_handle_t;
typedef struct cudart_init_resp {
char *err; // If err is non-null handle is invalid
cudart_handle_t ch;
int num_devices;
} cudart_init_resp_t;
typedef struct cudart_compute_capability {
char *err;
int major;
int minor;
} cudart_compute_capability_t;
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
void cudart_release(cudart_handle_t ch);
#endif // __GPU_INFO_CUDART_H__

View file

@ -1,221 +0,0 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include "gpu_info_nvml.h"
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
nvmlReturn_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
{NULL, NULL},
};
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
if (!resp->ch.handle) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
nvml_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
resp->ch.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->ch.nvmlInit_v2)();
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf);
return;
}
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
} else {
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
}
}
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
resp->err = NULL;
nvmlDevice_t device;
nvmlMemory_t memInfo = {0};
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle isn't initialized");
return;
}
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
if (h.verbose) {
nvmlBrandType_t brand = 0;
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
}
}
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
}
}
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
nvmlDevice_t device;
int major = 0;
int minor = 0;
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle not initialized");
return;
}
unsigned int devices;
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
}
void nvml_release(nvml_handle_t h) {
LOG(h.verbose, "releasing nvml library\n");
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}
#endif // __APPLE__

View file

@ -1,57 +0,0 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_NVML_H__
#define __GPU_INFO_NVML_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum nvmlReturn_enum {
NVML_SUCCESS = 0,
// Other values omitted for now...
} nvmlReturn_t;
typedef void *nvmlDevice_t; // Opaque is sufficient
typedef struct nvmlMemory_st {
unsigned long long total;
unsigned long long free;
unsigned long long used;
} nvmlMemory_t;
typedef enum nvmlBrandType_enum
{
NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;
typedef struct nvml_handle {
void *handle;
uint16_t verbose;
nvmlReturn_t (*nvmlInit_v2)(void);
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
} nvml_handle_t;
typedef struct nvml_init_resp {
char *err; // If err is non-null handle is invalid
nvml_handle_t ch;
} nvml_init_resp_t;
typedef struct nvml_compute_capability {
char *err;
int major;
int minor;
} nvml_compute_capability_t;
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
void nvml_release(nvml_handle_t ch);
#endif // __GPU_INFO_NVML_H__
#endif // __APPLE__

View file

@ -9,23 +9,16 @@ import (
func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo()
assert.Contains(t, "cuda rocm cpu metal", info.Library)
switch runtime.GOOS {
case "darwin":
// TODO - remove this once MacOS returns some size for CPU
return
case "linux", "windows":
assert.Greater(t, info.TotalMemory, uint64(0))
assert.Greater(t, info.FreeMemory, uint64(0))
assert.Greater(t, info.DeviceCount, uint32(0))
default:
return
assert.Greater(t, len(info), 0)
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
if info[0].Library != "cpu" {
assert.Greater(t, info[0].TotalMemory, uint64(0))
assert.Greater(t, info[0].FreeMemory, uint64(0))
}
}
func TestCPUMemInfo(t *testing.T) {
info, err := getCPUMem()
info, err := GetCPUMem()
assert.NoError(t, err)
switch runtime.GOOS {
case "darwin":

View file

@ -3,7 +3,6 @@ package gpu
type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
DeviceCount uint32 `json:"device_count,omitempty"`
}
// Beginning of an `ollama info` command
@ -17,11 +16,49 @@ type GpuInfo struct {
// MinimumMemory represents the minimum memory required to use the GPU
MinimumMemory uint64 `json:"-"`
// TODO add other useful attributes about the card here for discovery information
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
DependencyPath string `json:"lib_path,omitempty"`
// GPU information
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
Name string `json:"name"` // user friendly name if available
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
// TODO other performance capability info to help in scheduling decisions
}
type Version struct {
Major uint
Minor uint
Patch uint
type GpuInfoList []GpuInfo
// Split up the set of gpu info's by Library and variant
func (l GpuInfoList) ByLibrary() []GpuInfoList {
resp := []GpuInfoList{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
if info.Variant != "" {
requested += "_" + info.Variant
}
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, info.Library)
resp = append(resp, []GpuInfo{info})
}
}
return resp
}
// Sort by Free Space
type ByFreeMemory []GpuInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

View file

@ -4,7 +4,6 @@ package integration
import (
"context"
"net/http"
"testing"
"time"
@ -24,5 +23,5 @@ func TestOrcaMiniBlueSky(t *testing.T) {
"seed": 123,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}

View file

@ -0,0 +1,225 @@
//go:build integration
package integration
import (
"context"
"log/slog"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "tinydolphin",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
[]string{"sunlight"},
[]string{"england", "english", "massachusetts", "pilgrims"},
}
)
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, req[i], resp[i])
}(i)
}
wg.Wait()
}
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req, resp := GenerateRequests()
// Get the server running (if applicable) warm the model up with a single initial request
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
var wg sync.WaitGroup
wg.Add(len(req))
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
for j := 0; j < 5; j++ {
slog.Info("Starting", "req", i, "iter", j)
// On slower GPUs it can take a while to process the 4 concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}
max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err)
const MB = uint64(1024 * 1024)
type model struct {
name string
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
}
smallModels := []model{
{
name: "orca-mini",
size: 2992 * MB,
},
{
name: "phi",
size: 2616 * MB,
},
{
name: "gemma:2b",
size: 2364 * MB,
},
{
name: "stable-code:3b",
size: 2608 * MB,
},
{
name: "starcoder2:3b",
size: 2166 * MB,
},
}
mediumModels := []model{
{
name: "llama2",
size: 5118 * MB,
},
{
name: "mistral",
size: 4620 * MB,
},
{
name: "orca-mini:7b",
size: 5118 * MB,
},
{
name: "dolphin-mistral",
size: 4620 * MB,
},
{
name: "gemma:7b",
size: 5000 * MB,
},
// TODO - uncomment this once #3565 is merged and this is rebased on it
// {
// name: "codellama:7b",
// size: 5118 * MB,
// },
}
// These seem to be too slow to be useful...
// largeModels := []model{
// {
// name: "llama2:13b",
// size: 7400 * MB,
// },
// {
// name: "codellama:13b",
// size: 7400 * MB,
// },
// {
// name: "orca-mini:13b",
// size: 7400 * MB,
// },
// {
// name: "gemma:7b",
// size: 5000 * MB,
// },
// {
// name: "starcoder2:15b",
// size: 9100 * MB,
// },
// }
var chosenModels []model
switch {
case max < 10000*MB:
slog.Info("selecting small models")
chosenModels = smallModels
// case max < 30000*MB:
default:
slog.Info("selecting medium models")
chosenModels = mediumModels
// default:
// slog.Info("selecting large models")
// chosenModels = largModels
}
req, resp := GenerateRequests()
for i := range req {
if i > len(chosenModels) {
break
}
req[i].Model = chosenModels[i].name
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure all the models are pulled before we get started
for _, r := range req {
require.NoError(t, PullIfMissing(ctx, client, r.Model))
}
var wg sync.WaitGroup
consumed := uint64(256 * MB) // Assume some baseline usage
for i := 0; i < len(req); i++ {
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
if i > 1 && consumed > max {
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
break
}
consumed += chosenModels[i].size
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 3; j++ {
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}

View file

@ -4,7 +4,6 @@ package integration
import (
"context"
"net/http"
"testing"
"time"
@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
}

View file

@ -5,7 +5,6 @@ package integration
import (
"context"
"encoding/base64"
"net/http"
"testing"
"time"
@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
},
}
resp := "the ollamas"
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
GenerateTestHelper(ctx, t, req, []string{resp})
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

View file

@ -4,8 +4,6 @@ package integration
import (
"context"
"net/http"
"sync"
"testing"
"time"
@ -45,25 +43,5 @@ var (
func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
GenerateTestHelper(ctx, t, req[0], resp[0])
}
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
}(i)
}
wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency

View file

@ -5,13 +5,14 @@ package integration
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
@ -23,9 +24,13 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Init() {
lifecycle.InitLogging()
}
func FindPort() string {
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@ -41,7 +46,7 @@ func FindPort() string {
return strconv.Itoa(port)
}
func GetTestEndpoint() (string, string) {
func GetTestEndpoint() (*api.Client, string) {
defaultPort := "11434"
ollamaHost := os.Getenv("OLLAMA_HOST")
@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
port = FindPort()
}
url := fmt.Sprintf("%s:%s", host, port)
slog.Info("server connection", "url", url)
return scheme, url
slog.Info("server connection", "host", host, "port", port)
return api.NewClient(
&url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
}
// TODO make fanicier, grab logs, etc.
var serverMutex sync.Mutex
var serverReady bool
func StartServer(ctx context.Context, ollamaHost string) error {
func startServer(ctx context.Context, ollamaHost string) error {
// Make sure the server has been built
CLIName, err := filepath.Abs("../ollama")
if err != nil {
@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
return nil
}
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
return err
}
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
if err != nil {
showCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(5*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName),
)
defer cancel()
_, err := client.Show(showCtx, showReq)
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
break
case err != nil:
return err
}
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode == 200 {
default:
slog.Info("model already present", "model", modelName)
return nil
}
slog.Info("model missing", "status", response.StatusCode)
slog.Info("model missing", "model", modelName)
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
stallTimer := time.NewTimer(stallDuration)
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
if !stallTimer.Reset(stallDuration) {
return fmt.Errorf("stall was detected, aborting status reporting")
}
return nil
}
stream := true
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
requestJSON, err = json.Marshal(pullReq)
if err != nil {
return err
}
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
if err != nil {
return err
}
slog.Info("pulling", "model", modelName)
var pullError error
response, err = client.Do(req.WithContext(ctx))
if err != nil {
return err
done := make(chan int)
go func() {
pullError = client.Pull(ctx, pullReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
return fmt.Errorf("download stalled")
case <-done:
return pullError
}
defer response.Body.Close()
if response.StatusCode != 200 {
return fmt.Errorf("failed to pull model") // TODO more details perhaps
}
slog.Info("model pulled", "model", modelName)
return nil
}
var serverProcMutex sync.Mutex
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
// TODO maybe stuff in an init routine?
lifecycle.InitLogging()
requestJSON, err := json.Marshal(genReq)
if err != nil {
t.Fatalf("Error serializing request: %v", err)
// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
// Starts the server if needed
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
require.NoError(t, startServer(ctx, testEndpoint))
}
defer func() {
return client, testEndpoint, func() {
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err = os.Remove(lifecycle.ServerLogFile)
err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}
}()
scheme, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
assert.NoError(t, StartServer(ctx, testEndpoint))
}
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
if err != nil {
t.Fatalf("Error pulling model: %v", err)
}
// Make the request and get the response
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
// Set the content type for the request
req.Header.Set("Content-Type", "application/json")
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
t.Fatalf("Error making request: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body))
// Verify the response is valid JSON
var payload api.GenerateResponse
err = json.Unmarshal(body, &payload)
if err != nil {
assert.NoError(t, err, body)
}
// Verify the response contains the expected data
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(payload.Response), resp) {
atLeastOne = true
break
}
}
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
}
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
}
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.GenerateResponse) error {
// fmt.Print(".")
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
genReq.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(ctx, &genReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
}
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "why is the color of dirt brown?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of independence day?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the composition of air?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
},
[][]string{
[]string{"sunlight"},
[]string{"soil", "organic", "earth", "black", "tan"},
[]string{"england", "english", "massachusetts", "pilgrims"},
[]string{"fourth", "july", "declaration", "independence"},
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}

162
llm/memory.go Normal file
View file

@ -0,0 +1,162 @@
package llm
import (
"fmt"
"log/slog"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
)
// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
var estimatedVRAM uint64
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
// Split up the GPUs by type and try them
for _, gpus := range allGpus.ByLibrary() {
var layerCount int
layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
return true, estimatedVRAM
}
} else {
if layerCount > 0 && layerCount >= opts.NumGPU {
return true, estimatedVRAM
}
}
}
return false, estimatedVRAM
}
// Given a model and one or more GPU targets, predict how many layers and bytes we can load
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) {
if gpus[0].Library == "cpu" {
return 0, 0
}
var memoryAvailable uint64
for _, info := range gpus {
memoryAvailable += info.FreeMemory
}
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
memoryMinimum := gpus[0].MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(len(gpus))
graphPartialOffload *= uint64(len(gpus))
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if memoryRequiredPartial > memoryAvailable {
slog.Debug("insufficient VRAM to load any model layers")
return 0, 0
}
var layerCount int
layers := ggml.Tensors().Layers()
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
var memoryLayerOutput uint64
for k, v := range layers {
if !strings.HasPrefix(k, "blk.") {
memoryLayerOutput += v.size()
}
}
memoryRequiredTotal += memoryLayerOutput
if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
return layerCount, uint64(memoryRequiredPartial)
}

View file

@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"golang.org/x/exp/slices"
@ -138,6 +139,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
return servers
}
// Return the optimal server for this CPU architecture
func serverForCpu() string {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
return "metal"
}
variant := gpu.GetCPUVariant()
availableServers := availableServers()
if variant != "" {
for cmp := range availableServers {
if cmp == "cpu_"+variant {
return cmp
}
}
}
return "cpu"
}
// extract extracts the embedded files to the target directory
func extractFiles(targetDir string, glob string) error {
files, err := fs.Glob(libEmbed, glob)

View file

@ -21,21 +21,43 @@ import (
"strings"
"time"
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
)
// LlamaServer is an instance of the llama.cpp server
type LlamaServer struct {
type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
EstimatedVRAM() uint64
}
// llmServer is an instance of the llama.cpp server
type llmServer struct {
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options api.Options
// TODO - this should be broken down by GPU
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
sem *semaphore.Weighted
}
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
func LoadModel(model string) (*GGML, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
defer f.Close()
ggml, _, err := DecodeGGML(f)
if err != nil {
return nil, err
}
return ggml, err
}
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
var err error
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
@ -56,130 +81,50 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
opts.NumCtx = 4
}
memoryAvailable, _ := gpu.CheckVRAM()
info := gpu.GetGPUInfo()
cpuRunner := ""
var estimatedVRAM uint64
var systemMemory uint64
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
memoryMinimum := info.MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
cpuRunner = serverForCpu()
} else {
if gpus[0].Library == "metal" {
memInfo, err := gpu.GetCPUMem()
if err != nil {
slog.Error("failed to lookup system memory", "error", err)
} else {
systemMemory = memInfo.TotalMemory
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
}
}
var layers int
layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(info.DeviceCount)
graphPartialOffload *= uint64(info.DeviceCount)
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if info.Library != "metal" {
if memoryRequiredPartial > memoryAvailable {
info.Library = "cpu"
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
opts.NumGPU = 0
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
opts.NumGPU = layers
}
}
var layerCount int
layers := ggml.Tensors().Layers()
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
var memoryLayerOutput uint64
for k, v := range layers {
if !strings.HasPrefix(k, "blk.") {
memoryLayerOutput += v.size()
}
}
memoryRequiredTotal += memoryLayerOutput
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
// disable partial offloading when model is greater than total system memory
opts.NumGPU = 0
} else if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
if opts.NumGPU < 0 {
opts.NumGPU = layerCount
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
// Loop through potential servers
finalErr := fmt.Errorf("no suitable llama servers found")
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
availableServers := availableServers()
servers := serversForGpu(info)
var servers []string
if cpuRunner != "" {
servers = []string{cpuRunner}
} else {
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
}
demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
if demandLib != "" {
serverPath := availableServers[demandLib]
@ -192,7 +137,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
}
if len(servers) == 0 {
return nil, fmt.Errorf("no servers found for %v", info)
return nil, fmt.Errorf("no servers found for %v", gpus)
}
params := []string{
@ -249,8 +194,18 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--numa")
}
// Loop through potential servers
var finalErr error
// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
numParallel := 1
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
numParallel, err = strconv.Atoi(onp)
if err != nil || numParallel <= 0 {
err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
slog.Error("misconfiguration", "error", err)
return nil, err
}
}
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := 0; i < len(servers); i++ {
dir := availableServers[servers[i]]
@ -275,30 +230,49 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
}
// append the server directory to LD_LIBRARY_PATH/PATH
libraryPaths := []string{dir}
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
// Append our runner directory to the path
// This will favor system libraries over our bundled library dependencies
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
}
// Note: we always put the dependency path first
// since this was the exact version we verified for AMD GPUs
// and we favor what the user had in their path
if gpus[0].DependencyPath != "" {
// TODO refine for multi-gpu support
libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
}
server := filepath.Join(dir, "ollama_llama_server")
if runtime.GOOS == "windows" {
server = server + ".exe"
}
s := &LlamaServer{
port: port,
cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
s := &llmServer{
port: port,
cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
estimatedVRAM: estimatedVRAM,
sem: semaphore.NewWeighted(int64(numParallel)),
}
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
slog.Debug(libEnv)
s.cmd.Env = append(os.Environ(), libEnv)
s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status
// TODO - multiple GPU selection logic...
key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
if key != "" {
s.cmd.Env = append(s.cmd.Env, key+"="+val)
}
slog.Info("starting llama server", "cmd", s.cmd.String())
// Log at debug as the environment is inherited and might contain sensitive information
slog.Debug("subprocess", "environment", s.cmd.Env)
if err = s.cmd.Start(); err != nil {
msg := ""
@ -316,6 +290,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
_ = s.cmd.Wait()
}()
// TODO - make sure this is all wired up correctly
// if err = s.WaitUntilRunning(); err != nil {
// slog.Error("error starting llama server", "server", servers[i], "error", err)
// s.Close()
// finalErr = err
// continue
// }
return s, nil
}
@ -353,6 +334,21 @@ const ( // iota is reset to 0
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "llm server ready"
case ServerStatusNoSlotsAvaialble:
return "llm busy - no slots available"
case ServerStatusLoadingModel:
return "llm server loading model"
case ServerStatusNotResponding:
return "llm server not responding"
default:
return "llm server error"
}
}
type ServerStatusResp struct {
Status string `json:"status"`
SlotsIdle int `json:"slots_idle"`
@ -360,7 +356,7 @@ type ServerStatusResp struct {
Error string `json:"error"`
}
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
// Fail fast if its exited
if s.cmd.ProcessState != nil {
msg := ""
@ -407,7 +403,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
}
}
func (s *LlamaServer) Ping(ctx context.Context) error {
func (s *llmServer) Ping(ctx context.Context) error {
_, err := s.getServerStatus(ctx)
if err != nil {
slog.Debug("server unhealthy", "error", err)
@ -416,7 +412,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
return nil
}
func (s *LlamaServer) WaitUntilRunning() error {
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
start := time.Now()
// TODO we need to wire up a better way to detect hangs during model load and startup of the server
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
@ -427,6 +423,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
var lastStatus ServerStatus = -1
for {
select {
case <-ctx.Done():
slog.Info("context expired before server started")
return fmt.Errorf("timed out waiting for llama runner to start")
case err := <-s.done:
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
@ -450,9 +449,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel()
status, err := s.getServerStatus(ctx)
status, err := s.getServerStatus(c)
if err != nil && lastStatus != status {
slog.Debug("server not yet available", "error", err)
lastStatus = status
@ -538,7 +537,12 @@ type CompletionResponse struct {
EvalDuration time.Duration
}
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return err
}
defer s.sem.Release(1)
request := map[string]any{
"prompt": req.Prompt,
"stream": true,
@ -569,7 +573,7 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
if err != nil {
return err
} else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %d", status)
return fmt.Errorf("unexpected server status: %s", status.ToString())
}
if req.Format == "json" {
@ -716,13 +720,18 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status)
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
@ -768,13 +777,13 @@ type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status)
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: content})
@ -820,13 +829,13 @@ type DetokenizeResponse struct {
Content string `json:"content"`
}
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return "", err
} else if status != ServerStatusReady {
return "", fmt.Errorf("unexpected server status: %d", status)
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@ -864,7 +873,7 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
return decoded.Content, nil
}
func (s *LlamaServer) Close() error {
func (s *llmServer) Close() error {
if s.cmd != nil {
slog.Debug("stopping llama server")
return s.cmd.Process.Kill()
@ -873,6 +882,10 @@ func (s *LlamaServer) Close() error {
return nil
}
func (s *llmServer) EstimatedVRAM() uint64 {
return s.estimatedVRAM
}
func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {

View file

@ -15,11 +15,8 @@ import (
"os"
"os/signal"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
@ -38,7 +35,8 @@ import (
var mode string = gin.DebugMode
type Server struct {
addr net.Addr
addr net.Addr
sched *Scheduler
}
func init() {
@ -53,88 +51,8 @@ func init() {
gin.SetMode(mode)
}
var loaded struct {
mu sync.Mutex
llama *llm.LlamaServer
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
var defaultSessionDuration = 5 * time.Minute
func unload() {
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
needLoad := loaded.llama == nil || // is there a model loaded?
loaded.model != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
loaded.llama.Ping(ctx) != nil
if needLoad {
if loaded.llama != nil {
slog.Info("changing loaded model")
unload()
}
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
}
loaded.model = model.ModelPath
loaded.adapters = model.AdapterPaths
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = &opts
if err = llama.WaitUntilRunning(); err != nil {
slog.Error("error loading llama server", "error", err)
unload()
return err
}
}
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
unload()
})
}
loaded.expireTimer.Reset(sessionDuration)
return nil
}
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
@ -154,9 +72,7 @@ func isSupportedImageType(image []byte) bool {
return slices.Contains(allowedTypes, contentType)
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@ -224,7 +140,11 @@ func GenerateHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@ -275,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
sb.Reset()
if req.Context != nil {
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@ -297,9 +217,6 @@ func GenerateHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
@ -331,7 +248,7 @@ func GenerateHandler(c *gin.Context) {
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
@ -359,7 +276,7 @@ func GenerateHandler(c *gin.Context) {
Images: images,
Options: opts,
}
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@ -421,10 +338,7 @@ func getDefaultSessionDuration() time.Duration {
return defaultSessionDuration
}
func EmbeddingsHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
@ -469,7 +383,11 @@ func EmbeddingsHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@ -480,7 +398,7 @@ func EmbeddingsHandler(c *gin.Context) {
return
}
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@ -493,7 +411,7 @@ func EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func PullModelHandler(c *gin.Context) {
func (s *Server) PullModelHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
@ -542,7 +460,7 @@ func PullModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func PushModelHandler(c *gin.Context) {
func (s *Server) PushModelHandler(c *gin.Context) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
@ -591,7 +509,7 @@ func PushModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func CreateModelHandler(c *gin.Context) {
func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
err := c.ShouldBindJSON(&req)
switch {
@ -664,7 +582,7 @@ func CreateModelHandler(c *gin.Context) {
streamResponse(c, ch)
}
func DeleteModelHandler(c *gin.Context) {
func (s *Server) DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
err := c.ShouldBindJSON(&req)
switch {
@ -709,7 +627,7 @@ func DeleteModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, nil)
}
func ShowModelHandler(c *gin.Context) {
func (s *Server) ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
err := c.ShouldBindJSON(&req)
switch {
@ -809,7 +727,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil
}
func ListModelsHandler(c *gin.Context) {
func (s *Server) ListModelsHandler(c *gin.Context) {
models := make([]api.ModelResponse, 0)
manifestsPath, err := GetManifestPath()
if err != nil {
@ -869,7 +787,7 @@ func ListModelsHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
func CopyModelHandler(c *gin.Context) {
func (s *Server) CopyModelHandler(c *gin.Context) {
var req api.CopyRequest
err := c.ShouldBindJSON(&req)
switch {
@ -901,7 +819,7 @@ func CopyModelHandler(c *gin.Context) {
}
}
func HeadBlobHandler(c *gin.Context) {
func (s *Server) HeadBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -916,7 +834,7 @@ func HeadBlobHandler(c *gin.Context) {
c.Status(http.StatusOK)
}
func CreateBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -1063,27 +981,27 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr),
)
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingsHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
r.POST("/api/copy", CopyModelHandler)
r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler)
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler)
r.POST("/api/copy", s.CopyModelHandler)
r.DELETE("/api/delete", s.DeleteModelHandler)
r.POST("/api/show", s.ShowModelHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", ListModelsHandler)
r.Handle(method, "/api/tags", s.ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
@ -1137,7 +1055,9 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr()}
ctx, done := context.WithCancel(context.Background())
sched := InitScheduler(ctx)
s := &Server{addr: ln.Addr(), sched: sched}
r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
@ -1150,7 +1070,8 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
unload()
done()
sched.unloadAllRunners()
gpu.Cleanup()
os.Exit(0)
}()
@ -1158,12 +1079,12 @@ func Serve(ln net.Listener) error {
if err := llm.Init(); err != nil {
return fmt.Errorf("unable to initialize llm library %w", err)
}
if runtime.GOOS == "linux" { // TODO - windows too
// check compatibility to log warnings
if _, err := gpu.CheckVRAM(); err != nil {
slog.Info(err.Error())
}
}
s.sched.Run(ctx)
// At startup we retrieve GPU information so we can get log messages before loading a model
// This will log warnings to the log in case we have problems with detected GPUs
_ = gpu.GetGPUInfo()
return srvr.Serve(ln)
}
@ -1219,9 +1140,9 @@ func streamResponse(c *gin.Context, ch chan any) {
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.llama.Tokenize(ctx, s)
return runner.llama.Tokenize(ctx, s)
}
prompt, err := ChatPrompt(template, messages, numCtx, encode)
@ -1232,10 +1153,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu
return prompt, nil
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest
@ -1292,7 +1210,11 @@ func ChatHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
case runner = <-rCh:
case err = <-eCh:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@ -1309,7 +1231,7 @@ func ChatHandler(c *gin.Context) {
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@ -1352,8 +1274,6 @@ func ChatHandler(c *gin.Context) {
defer close(ch)
fn := func(r llm.CompletionResponse) {
// Update model expiration
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
@ -1376,7 +1296,7 @@ func ChatHandler(c *gin.Context) {
ch <- resp
}
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Format: req.Format,
Images: images,

525
server/sched.go Normal file
View file

@ -0,0 +1,525 @@
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"reflect"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"golang.org/x/exp/slices"
)
type LlmRequest struct {
ctx context.Context //nolint:containedctx
model *Model
ggml *llm.GGML // TODO - how large is this, and do we need to free it after we've finished loading?
opts api.Options
sessionDuration time.Duration
successCh chan *runnerRef
errCh chan error
}
type Scheduler struct {
pendingReqCh chan *LlmRequest
finishedReqCh chan *LlmRequest
expiredCh chan *runnerRef
unloadedCh chan interface{}
loaded map[string]*runnerRef
loadedMu sync.Mutex
loadFn func(req *LlmRequest, gpus gpu.GpuInfoList)
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
getGpuFn func() gpu.GpuInfoList
}
// TODO set this to zero after a release or two, to enable multiple models by default
var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
var maxQueuedRequests = 10 // TODO configurable
func InitScheduler(ctx context.Context) *Scheduler {
maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" {
m, err := strconv.Atoi(maxRunners)
if err != nil {
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
} else {
loadedMax = m
}
}
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, maxQueuedRequests),
finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
expiredCh: make(chan *runnerRef, maxQueuedRequests),
unloadedCh: make(chan interface{}, maxQueuedRequests),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
}
sched.loadFn = sched.load
return sched
}
// context must be canceled to decrement ref count and release the runner
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
ggml, err := llm.LoadModel(model.ModelPath)
req := &LlmRequest{
ctx: c,
model: model,
ggml: ggml,
opts: opts,
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef),
errCh: make(chan error, 1),
}
if err != nil {
req.errCh <- err
return req.successCh, req.errCh
}
select {
case s.pendingReqCh <- req:
default:
req.errCh <- fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
}
return req.successCh, req.errCh
}
// Returns immediately, spawns go routines for the scheduler which will shutdown when ctx is done
func (s *Scheduler) Run(ctx context.Context) {
slog.Debug("starting llm scheduler")
go func() {
s.processPending(ctx)
}()
go func() {
s.processCompleted(ctx)
}()
}
func (s *Scheduler) processPending(ctx context.Context) {
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running
for {
var runnerToExpire *runnerRef
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
runnerToExpire = runner
} else {
// Runner is usable, return it
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if loadedCount == 0 {
slog.Debug("loading first model", "model", pending.model.ModelPath)
gpus := s.getGpuFn()
g := pickBestFitGPUs(pending, gpus)
if g != nil {
gpus = g
}
s.loadFn(pending, gpus)
break
} else if loadedMax > 0 && loadedCount >= loadedMax {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload(pending)
} else {
// More than one loaded model, so we have to see if the new one fits
// Get a refreshed GPU list
gpus := s.getGpuFn()
// Update free memory from currently loaded models
s.updateFreeSpace(gpus)
gpus = pickBestFitGPUs(pending, gpus)
if gpus != nil {
slog.Debug("new model fits with existing models, loading")
s.loadFn(pending, gpus)
break
}
runnerToExpire = s.findRunnerToUnload(pending)
}
if runnerToExpire == nil {
// Shouildn't happen
slog.Error("runner to expire was nil!")
continue
}
// Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil
}
runnerToExpire.sessionDuration = 0
if runnerToExpire.refCount <= 0 {
s.expiredCh <- runnerToExpire
}
runnerToExpire.refMu.Unlock()
// Wait for the unload to happen
// Note: at this point we're queueing up all incoming requests, even if they were for
// a different model that's loaded and not scheduled to be removed.
slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case <-s.unloadedCh:
slog.Debug("unload completed", "model", runnerToExpire.model)
continue
}
}
case <-s.unloadedCh:
// An unload request when there are no pending request can be ignored
slog.Debug("ignoring unload event with no pending requests")
}
}
}
func (s *Scheduler) processCompleted(ctx context.Context) {
// Process completed requests, expired timers, and unloading models
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
s.loadedMu.Lock()
runner := s.loaded[finished.model.ModelPath]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
continue
}
runner.refMu.Lock()
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "model", runner.model)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
}
s.expiredCh <- runner
})
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
}
}
slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
runner.refMu.Unlock()
case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "model", runner.model)
runner.refMu.Lock()
if runner.refCount > 0 {
// Shouldn't happen, but safeguard to ensure no leaked runners
slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
time.Sleep(10 * time.Millisecond)
s.expiredCh <- runner
}(runner)
runner.refMu.Unlock()
continue
}
slog.Debug("got lock to unload", "model", runner.model)
runner.unload()
s.loadedMu.Lock()
delete(s.loaded, runner.model)
s.loadedMu.Unlock()
slog.Debug("runner released", "model", runner.model)
runner.refMu.Unlock()
slog.Debug("sending an unloaded event", "model", runner.model)
s.unloadedCh <- struct{}{}
}
}
}
// Complete the pending request and send the runner back to the requester
// Wires up a finished event after the request context is completed
// Updates session duration, and resets expiration timer
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
runner.refMu.Lock()
defer runner.refMu.Unlock()
runner.refCount++
runner.sessionDuration = pending.sessionDuration
pending.successCh <- runner
go func() {
<-pending.ctx.Done()
slog.Debug("context for request finished")
finished <- pending
}()
}
func (s *Scheduler) load(req *LlmRequest, gpus gpu.GpuInfoList) {
llama, err := s.newServerFn(gpus, req.model.ModelPath, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
req.errCh <- err
return
}
runner := &runnerRef{}
runner.model = req.model.ModelPath
runner.adapters = req.model.AdapterPaths
runner.projectors = req.model.ProjectorPaths
runner.llama = llama
runner.Options = &req.opts
runner.sessionDuration = req.sessionDuration
runner.gpus = gpus
runner.estimatedVRAM = llama.EstimatedVRAM()
runner.loading = true
runner.refCount = 1
runner.refMu.Lock()
s.loadedMu.Lock()
s.loaded[req.model.ModelPath] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
go func() {
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
runner.refCount--
req.errCh <- err
slog.Debug("triggering expiration for failed load", "model", runner.model)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
runner.loading = false
go func() {
<-req.ctx.Done()
slog.Debug("context for request finished")
s.finishedReqCh <- req
}()
req.successCh <- runner
}()
}
func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
type predKey struct {
Library string
ID string
}
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
for _, r := range s.loaded {
r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus {
gpuIDs = append(gpuIDs, gpu.ID)
}
for _, gpu := range allGpus {
if slices.Contains(gpuIDs, gpu.ID) {
predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
}
}
} else {
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
}
r.refMu.Unlock()
}
s.loadedMu.Unlock()
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
for i := range allGpus {
if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
if p > allGpus[i].TotalMemory {
// Shouldn't happen
slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
allGpus[i].FreeMemory = 0
} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
// after we start our first runner, then we'll never acount for that, so picking the smallest free value seems prudent.
allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
}
slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
}
}
}
type runnerRef struct {
refMu sync.Mutex
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
refCount uint // prevent unloading if > 0
// unloading bool // set to true when we are trying to unload the runner
llama llm.LlamaServer
loading bool // True only during initial load, then false forever
gpus gpu.GpuInfoList // Recorded at time of provisioning
estimatedVRAM uint64
sessionDuration time.Duration
expireTimer *time.Timer
model string
adapters []string
projectors []string
*api.Options
}
// The refMu must already be held when calling unload
func (runner *runnerRef) unload() {
if runner.llama != nil {
runner.llama.Close()
}
runner.llama = nil
runner.adapters = nil
runner.projectors = nil
runner.Options = nil
runner.gpus = nil
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
runner.refMu.Lock()
defer runner.refMu.Unlock()
// Ignore the NumGPU settings for comparison
optsExisting := runner.Options.Runner
optsExisting.NumGPU = -1
optsNew := req.opts.Runner
optsNew.NumGPU = -1
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
}
ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
defer cancel()
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
runner.llama.Ping(ctx) != nil {
return true
}
return false
}
type ByDuration []*runnerRef
func (a ByDuration) Len() int { return len(a) }
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDuration) Less(i, j int) bool {
// uint64 to turn negative time (never unload) to largest
return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
}
// TODO - future consideration to pick runners based on size
// type BySize []*runnerRef
// func (a BySize) Len() int { return len(a) }
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
// If the model can not be fit fully within the available GPU(s) nil is returned
func pickBestFitGPUs(req *LlmRequest, gpus gpu.GpuInfoList) gpu.GpuInfoList {
var estimatedVRAM uint64
for _, gl := range gpus.ByLibrary() {
var ok bool
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
// TODO - potentially sort by performance capability, existing models loaded, etc.
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
// First attempt to fit the model into a single GPU
for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
return []gpu.GpuInfo{g}
}
}
// TODO future refinements
// - if multiple Libraries, see if any single GPU in any Library will fit
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUs
if ok, estimatedVRAM = llm.PredictServerFit(gl, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
return gl
}
}
return nil
}
// findRunnerToUnload finds a runner to unload to make room for a new model
func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
s.loadedMu.Lock()
runnerList := make([]*runnerRef, 0, len(s.loaded))
for _, r := range s.loaded {
runnerList = append(runnerList, r)
}
s.loadedMu.Unlock()
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
// e.g., if we have multiple options, will one make room for the request?
sort.Sort(ByDuration(runnerList))
// First try to find a runner that's already idle
for _, runner := range runnerList {
runner.refMu.Lock()
rc := runner.refCount
runner.refMu.Unlock()
if rc == 0 {
slog.Debug("found an idle runner to unload")
return runner
}
}
// None appear idle, just wait for the one with the shortest duration
slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
return runnerList[0]
}
func (s *Scheduler) unloadAllRunners() {
s.loadedMu.Lock()
defer s.loadedMu.Unlock()
for model, runner := range s.loaded {
if runner.llama != nil {
slog.Debug("shutting down runner", "model", model)
runner.llama.Close()
}
}
}

553
server/sched_test.go Normal file
View file

@ -0,0 +1,553 @@
package server
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"log/slog"
"os"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
os.Setenv("OLLAMA_DEBUG", "1")
lifecycle.InitLogging()
}
func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
defer done()
initialMax := loadedMax
s := InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
s = InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
require.NotNil(t, s.loaded)
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
s = InitScheduler(ctx)
require.Equal(t, 0, loadedMax)
require.NotNil(t, s.loaded)
}
func TestLoad(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
s := InitScheduler(ctx)
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
}
// Fail to load model first
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return nil, fmt.Errorf("something failed to load model blah")
}
gpus := gpu.GpuInfoList{}
s.load(req, gpus)
require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1)
require.Len(t, s.loaded, 0)
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
server := &mockLlm{estimatedVRAM: 10}
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return server, nil
}
s.load(req, gpus)
select {
case err := <-req.errCh:
require.NoError(t, err)
case resp := <-req.successCh:
require.Equal(t, uint64(10), resp.estimatedVRAM)
require.Equal(t, uint(1), resp.refCount)
require.Len(t, s.loaded, 1)
}
req.model.ModelPath = "dummy_model_path"
server.waitResp = fmt.Errorf("wait failure")
s.load(req, gpus)
select {
case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure")
case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp)
}
runner := s.loaded["dummy_model_path"]
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
require.Len(t, s.expiredCh, 1)
}
type bundle struct {
ctx context.Context //nolint:containedctx
ctxDone func()
srv *mockLlm
req *LlmRequest
}
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
return scenario.srv, nil
}
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
scenario := &bundle{}
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName)
assert.Nil(t, err)
defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian)
err = gguf.Encode(f, llm.KV{
"general.architecture": "llama",
"general.name": "name",
"llama.context_length": uint32(32),
"llama.embedding_length": uint32(4096),
"llama.block_count": uint32(1),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(32),
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
})
assert.Nil(t, err)
fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname}
ggml, err := llm.LoadModel(model.ModelPath)
require.NoError(t, err)
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
ggml: ggml,
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
}
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
return scenario
}
func TestRequests(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.req.ggml = scenario1a.req.ggml
scenario1b.req.sessionDuration = 0
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
scenario2a.req.model = scenario1a.req.model
scenario2a.req.ggml = scenario1a.req.ggml
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Same runner as first request due to not needing a reload
s.newServerFn = scenario1b.newServer
slog.Info("scenario1b")
s.pendingReqCh <- scenario1b.req
select {
case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario1b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
// Trigger a reload
s.newServerFn = scenario2a.newServer
scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("scenario2a")
s.pendingReqCh <- scenario2a.req
// finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond)
scenario1a.ctxDone()
scenario1b.ctxDone()
select {
case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario2a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
loadedMax = 1
s.newServerFn = scenario3a.newServer
slog.Info("scenario3a")
s.pendingReqCh <- scenario3a.req
// finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select {
case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3a.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
loadedMax = 0
s.newServerFn = scenario3b.newServer
slog.Info("scenario3b")
s.pendingReqCh <- scenario3b.req
select {
case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3b.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 2)
// Try to load a model that wont fit
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
require.Len(t, s.loaded, 2)
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3c.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
require.Len(t, s.loaded, 1)
scenario3b.ctxDone()
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3c.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
require.Len(t, s.loaded, 1)
}
func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
scenario1a.req.sessionDuration = 0
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = 0
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = 0
maxQueuedRequests = 1
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
require.Len(t, successCh1b, 0)
require.Len(t, errCh1b, 1)
err := <-errCh1b
require.Contains(t, err.Error(), "server busy")
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
scenario1a.ctxDone()
require.Len(t, s.loaded, 1)
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, successCh1c, 0)
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
scenario1b.ctxDone()
time.Sleep(5 * time.Millisecond)
require.Len(t, s.loaded, 0)
}
// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done()
// Same model, same request
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
require.Len(t, s.loaded, 1)
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case <-ctx.Done():
t.Errorf("timeout")
}
time.Sleep(scenario1a.req.sessionDuration)
scenario1a.ctxDone()
time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond)
require.Len(t, s.finishedReqCh, 0)
require.Len(t, s.loaded, 0)
// also shouldn't happen in real life
s.finishedReqCh <- scenario1a.req
time.Sleep(5 * time.Millisecond)
}
func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
}
finished := make(chan *LlmRequest)
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1, sessionDuration: 1}
req.useLoadedRunner(r1, finished)
require.Equal(t, uint(1), r1.refCount)
require.Equal(t, time.Duration(2), r1.sessionDuration)
select {
case success := <-req.successCh:
require.Equal(t, r1, success)
case <-ctx.Done():
t.Errorf("timeout")
}
done()
fin := <-finished
require.Equal(t, req, fin)
}
func TestUpdateFreeSpace(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
gpus := gpu.GpuInfoList{
{
Library: "a",
ID: "1",
},
{
Library: "a",
ID: "2",
},
}
gpus[0].TotalMemory = 1000
gpus[0].FreeMemory = 900
gpus[1].TotalMemory = 2000
gpus[1].FreeMemory = 1900
llm1 := &mockLlm{estimatedVRAM: 100}
llm2 := &mockLlm{estimatedVRAM: 200}
r1 := &runnerRef{llama: llm1, gpus: gpus}
r2 := &runnerRef{llama: llm2, gpus: gpus}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
s.updateFreeSpace(gpus)
require.Equal(t, uint64(850), gpus[0].FreeMemory)
require.Equal(t, uint64(1850), gpus[1].FreeMemory)
}
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
req := &LlmRequest{ctx: ctx}
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
r2 := &runnerRef{sessionDuration: 2}
s := InitScheduler(ctx)
s.loaded["a"] = r1
s.loaded["b"] = r2
resp := s.findRunnerToUnload(req)
require.Equal(t, r2, resp)
r2.refCount = 1
resp = s.findRunnerToUnload(req)
require.Equal(t, r1, resp)
}
func TestNeedsReload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm := &mockLlm{}
runner := &runnerRef{
adapters: []string{"adapter1"},
projectors: []string{"projector1"},
Options: &api.Options{},
llama: llm,
}
req := &LlmRequest{
model: &Model{
AdapterPaths: []string{"adapter2"},
ProjectorPaths: []string{"projector2"},
},
opts: api.Options{},
}
resp := runner.needsReload(ctx, req)
require.True(t, resp)
req.model.AdapterPaths = runner.adapters
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.model.ProjectorPaths = runner.projectors
runner.loading = true
req.opts.NumBatch = 1234
resp = runner.needsReload(ctx, req)
require.True(t, resp)
req.opts.NumBatch = runner.Options.NumBatch
llm.pingResp = fmt.Errorf("foo")
resp = runner.needsReload(ctx, req)
require.True(t, resp)
llm.pingResp = nil
resp = runner.needsReload(ctx, req)
require.False(t, resp)
req.opts.NumGPU = 99
resp = runner.needsReload(ctx, req)
require.False(t, resp)
}
func TestUnloadAllRunners(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
llm1 := &mockLlm{}
llm2 := &mockLlm{}
s := InitScheduler(ctx)
s.unloadAllRunners()
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{llama: llm2}
s.loaded["a"] = r1
s.loaded["b"] = r2
s.unloadAllRunners()
require.True(t, llm1.closeCalled)
require.True(t, llm2.closeCalled)
}
func TestUnload(t *testing.T) {
llm1 := &mockLlm{}
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{adapters: []string{"A"}}
r1.unload()
require.True(t, llm1.closeCalled)
r2.unload()
require.Nil(t, r2.adapters)
}
type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
detonekizeRespErr error
closeResp error
closeCalled bool
estimatedVRAM uint64
}
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr
}
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
return s.detokenizeResp, s.detonekizeRespErr
}
func (s *mockLlm) Close() error {
s.closeCalled = true
return s.closeResp
}
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }