Merge pull request #3418 from dhiltgen/concurrency
Request and model concurrency
This commit is contained in:
commit
5690e5ce99
30 changed files with 2615 additions and 1387 deletions
|
@ -91,6 +91,13 @@ func ClientFromEnvironment() (*Client, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewClient(base *url.URL, http *http.Client) *Client {
|
||||||
|
return &Client{
|
||||||
|
base: base,
|
||||||
|
http: http,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
||||||
var reqBody io.Reader
|
var reqBody io.Reader
|
||||||
var data []byte
|
var data []byte
|
||||||
|
|
|
@ -15,6 +15,7 @@ const (
|
||||||
|
|
||||||
KibiByte = Byte * 1024
|
KibiByte = Byte * 1024
|
||||||
MebiByte = KibiByte * 1024
|
MebiByte = KibiByte * 1024
|
||||||
|
GibiByte = MebiByte * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func HumanBytes(b int64) string {
|
func HumanBytes(b int64) string {
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,22 +35,64 @@ func GetSupportedGFX(libDir string) ([]string, error) {
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
|
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||||
// Set the visible devices if not already set
|
ids := []string{}
|
||||||
// TODO - does sort order matter?
|
for _, info := range gpuInfo {
|
||||||
devices := []string{}
|
if info.Library != "rocm" {
|
||||||
for i := range ids {
|
// TODO shouldn't happen if things are wired correctly...
|
||||||
if _, skipped := skip[i]; skipped {
|
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
devices = append(devices, strconv.Itoa(i))
|
ids = append(ids, info.ID)
|
||||||
|
}
|
||||||
|
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func commonAMDValidateLibDir() (string, error) {
|
||||||
|
// We try to favor system paths first, so that we can wire up the subprocess to use
|
||||||
|
// the system version. Only use our bundled version if the system version doesn't work
|
||||||
|
// This gives users a more recovery options if versions have subtle problems at runtime
|
||||||
|
|
||||||
|
// Prefer explicit HIP env var
|
||||||
|
hipPath := os.Getenv("HIP_PATH")
|
||||||
|
if hipPath != "" {
|
||||||
|
hipLibDir := filepath.Join(hipPath, "bin")
|
||||||
|
if rocmLibUsable(hipLibDir) {
|
||||||
|
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
||||||
|
return hipLibDir, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val := strings.Join(devices, ",")
|
// Scan the LD_LIBRARY_PATH or PATH
|
||||||
err := os.Setenv("HIP_VISIBLE_DEVICES", val)
|
pathEnv := "LD_LIBRARY_PATH"
|
||||||
if err != nil {
|
if runtime.GOOS == "windows" {
|
||||||
slog.Warn(fmt.Sprintf("failed to set env: %s", err))
|
pathEnv = "PATH"
|
||||||
} else {
|
|
||||||
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
paths := os.Getenv(pathEnv)
|
||||||
|
for _, path := range filepath.SplitList(paths) {
|
||||||
|
d, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rocmLibUsable(d) {
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Well known location(s)
|
||||||
|
if rocmLibUsable(RocmStandardLocation) {
|
||||||
|
return RocmStandardLocation, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Installer payload location if we're running the installed binary
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err == nil {
|
||||||
|
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||||
|
if rocmLibUsable(rocmTargetDir) {
|
||||||
|
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||||
|
return rocmTargetDir, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
|
||||||
func (hl *HipLib) Release() {
|
func (hl *HipLib) Release() {
|
||||||
err := windows.FreeLibrary(hl.dll)
|
err := windows.FreeLibrary(hl.dll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
|
slog.Warn("failed to unload amdhip64.dll", "error", err)
|
||||||
}
|
}
|
||||||
hl.dll = 0
|
hl.dll = 0
|
||||||
}
|
}
|
||||||
|
@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
if status != hipSuccess {
|
if status != hipSuccess {
|
||||||
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
|
slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
|
||||||
}
|
}
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
483
gpu/amd_linux.go
483
gpu/amd_linux.go
|
@ -11,6 +11,8 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Discovery logic for AMD/ROCm GPUs
|
// Discovery logic for AMD/ROCm GPUs
|
||||||
|
@ -24,9 +26,6 @@ const (
|
||||||
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
||||||
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
|
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
|
||||||
RocmStandardLocation = "/opt/rocm/lib"
|
RocmStandardLocation = "/opt/rocm/lib"
|
||||||
|
|
||||||
// TODO find a better way to detect iGPU instead of minimum memory
|
|
||||||
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -35,14 +34,11 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
||||||
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
|
func AMDGetGPUInfo() []GpuInfo {
|
||||||
// and the user hasn't already set this variable
|
resp := []GpuInfo{}
|
||||||
func AMDGetGPUInfo(resp *GpuInfo) {
|
|
||||||
// TODO - DRY this out with windows
|
|
||||||
if !AMDDetected() {
|
if !AMDDetected() {
|
||||||
return
|
return resp
|
||||||
}
|
}
|
||||||
skip := map[int]interface{}{}
|
|
||||||
|
|
||||||
// Opportunistic logging of driver version to aid in troubleshooting
|
// Opportunistic logging of driver version to aid in troubleshooting
|
||||||
ver, err := AMDDriverVersion()
|
ver, err := AMDDriverVersion()
|
||||||
|
@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||||
slog.Info("AMD Driver: " + ver)
|
slog.Info("AMD Driver: " + ver)
|
||||||
} else {
|
} else {
|
||||||
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
||||||
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
|
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the user has specified exactly which GPUs to use, look up their memory
|
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
|
||||||
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
|
var visibleDevices []string
|
||||||
if visibleDevices != "" {
|
hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
|
||||||
ids := []int{}
|
rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
|
||||||
for _, idStr := range strings.Split(visibleDevices, ",") {
|
gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
|
||||||
id, err := strconv.Atoi(idStr)
|
switch {
|
||||||
if err != nil {
|
// TODO is this priorty order right?
|
||||||
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
|
case hipVD != "":
|
||||||
} else {
|
visibleDevices = strings.Split(hipVD, ",")
|
||||||
ids = append(ids, id)
|
case rocrVD != "":
|
||||||
}
|
visibleDevices = strings.Split(rocrVD, ",")
|
||||||
}
|
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
|
||||||
amdProcMemLookup(resp, nil, ids)
|
// all our test systems show GPU-XX indicating UUID is not supported
|
||||||
return
|
case gpuDO != "":
|
||||||
|
visibleDevices = strings.Split(gpuDO, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gather GFX version information from all detected cards
|
|
||||||
gfx := AMDGFXVersions()
|
|
||||||
verStrings := []string{}
|
|
||||||
for i, v := range gfx {
|
|
||||||
verStrings = append(verStrings, v.ToGFXString())
|
|
||||||
if v.Major == 0 {
|
|
||||||
// Silently skip CPUs
|
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if v.Major < 9 {
|
|
||||||
// TODO consider this a build-time setting if we can support 8xx family GPUs
|
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
|
|
||||||
skip[i] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
|
|
||||||
|
|
||||||
// Abort if all GPUs are skipped
|
|
||||||
if len(skip) >= len(gfx) {
|
|
||||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
|
|
||||||
libDir, err := AMDValidateLibDir()
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
updateLibPath(libDir)
|
|
||||||
|
|
||||||
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
||||||
if gfxOverride == "" {
|
var supported []string
|
||||||
supported, err := GetSupportedGFX(libDir)
|
libDir := ""
|
||||||
|
|
||||||
|
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
|
||||||
|
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
||||||
|
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
||||||
|
cpuCount := 0
|
||||||
|
for _, match := range matches {
|
||||||
|
slog.Debug("evaluating amdgpu node " + match)
|
||||||
|
fp, err := os.Open(match)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
|
slog.Debug("failed to open sysfs node", "file", match, "error", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
|
|
||||||
|
|
||||||
for i, v := range gfx {
|
|
||||||
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
|
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
|
|
||||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
|
||||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
|
|
||||||
skip[i] = struct{}{}
|
|
||||||
} else {
|
|
||||||
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(skip) >= len(gfx) {
|
|
||||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ids := make([]int, len(gfx))
|
|
||||||
i := 0
|
|
||||||
for k := range gfx {
|
|
||||||
ids[i] = k
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
amdProcMemLookup(resp, skip, ids)
|
|
||||||
if resp.memInfo.DeviceCount == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(skip) > 0 {
|
|
||||||
amdSetVisibleDevices(ids, skip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateLibPath(libDir string) {
|
|
||||||
ldPaths := []string{}
|
|
||||||
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
|
||||||
ldPaths = strings.Split(val, ":")
|
|
||||||
}
|
|
||||||
for _, d := range ldPaths {
|
|
||||||
if d == libDir {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val := strings.Join(append(ldPaths, libDir), ":")
|
|
||||||
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
|
|
||||||
os.Setenv("LD_LIBRARY_PATH", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Walk the sysfs nodes for the available GPUs and gather information from them
|
|
||||||
// skipping over any devices in the skip map
|
|
||||||
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
|
||||||
resp.memInfo.DeviceCount = 0
|
|
||||||
resp.memInfo.TotalMemory = 0
|
|
||||||
resp.memInfo.FreeMemory = 0
|
|
||||||
slog.Debug("discovering VRAM for amdgpu devices")
|
|
||||||
if len(ids) == 0 {
|
|
||||||
entries, err := os.ReadDir(AMDNodesSysfsDir)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, node := range entries {
|
|
||||||
if !node.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
id, err := strconv.Atoi(node.Name())
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ids = append(ids, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
|
|
||||||
|
|
||||||
for _, id := range ids {
|
|
||||||
if _, skipped := skip[id]; skipped {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
defer fp.Close()
|
||||||
|
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
|
||||||
|
if err != nil {
|
||||||
|
slog.Debug("failed to parse node ID", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(fp)
|
||||||
|
isCPU := false
|
||||||
|
var major, minor, patch uint64
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
|
||||||
|
if strings.HasPrefix(line, "gfx_target_version") {
|
||||||
|
ver := strings.Fields(line)
|
||||||
|
|
||||||
|
// Detect CPUs
|
||||||
|
if len(ver) == 2 && ver[1] == "0" {
|
||||||
|
slog.Debug("detected CPU " + match)
|
||||||
|
isCPU = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ver) != 2 || len(ver[1]) < 5 {
|
||||||
|
slog.Warn("malformed "+match, "gfx_target_version", line)
|
||||||
|
// If this winds up being a CPU, our offsets may be wrong
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
l := len(ver[1])
|
||||||
|
var err1, err2, err3 error
|
||||||
|
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
||||||
|
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
||||||
|
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
|
||||||
|
if err1 != nil || err2 != nil || err3 != nil {
|
||||||
|
slog.Debug("malformed int " + line)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - any other properties we want to extract and record?
|
||||||
|
// vendor_id + device_id -> pci lookup for "Name"
|
||||||
|
// Other metrics that may help us understand relative performance between multiple GPUs
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCPU {
|
||||||
|
cpuCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// CPUs are always first in the list
|
||||||
|
gpuID := nodeID - cpuCount
|
||||||
|
|
||||||
|
// Shouldn't happen, but just in case...
|
||||||
|
if gpuID < 0 {
|
||||||
|
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
|
||||||
|
return []GpuInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(major) < RocmComputeMin {
|
||||||
|
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%d", major, minor, patch), "gpu", gpuID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the memory for the current node
|
||||||
totalMemory := uint64(0)
|
totalMemory := uint64(0)
|
||||||
usedMemory := uint64(0)
|
usedMemory := uint64(0)
|
||||||
// Adjust for sysfs vs HIP ids
|
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
|
||||||
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
|
|
||||||
propFiles, err := filepath.Glob(propGlob)
|
propFiles, err := filepath.Glob(propGlob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
|
slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
|
||||||
}
|
}
|
||||||
// 1 or more memory banks - sum the values of all of them
|
// 1 or more memory banks - sum the values of all of them
|
||||||
for _, propFile := range propFiles {
|
for _, propFile := range propFiles {
|
||||||
fp, err := os.Open(propFile)
|
fp, err := os.Open(propFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
|
slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer fp.Close()
|
defer fp.Close()
|
||||||
|
@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if totalMemory == 0 {
|
if totalMemory == 0 {
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
|
slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
|
||||||
skip[id] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if totalMemory < IGPUMemLimit {
|
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
|
||||||
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
|
|
||||||
skip[id] = struct{}{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
|
|
||||||
usedFiles, err := filepath.Glob(usedGlob)
|
usedFiles, err := filepath.Glob(usedGlob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
|
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, usedFile := range usedFiles {
|
for _, usedFile := range usedFiles {
|
||||||
fp, err := os.Open(usedFile)
|
fp, err := os.Open(usedFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
|
slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer fp.Close()
|
defer fp.Close()
|
||||||
data, err := io.ReadAll(fp)
|
data, err := io.ReadAll(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
|
slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
|
slog.Warn("malformed used memory", "data", string(data), "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
usedMemory += used
|
usedMemory += used
|
||||||
}
|
}
|
||||||
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
|
|
||||||
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
|
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||||
resp.memInfo.DeviceCount++
|
if totalMemory < IGPUMemLimit {
|
||||||
resp.memInfo.TotalMemory += totalMemory
|
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||||
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||||
|
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||||
|
gpuInfo := GpuInfo{
|
||||||
|
Library: "rocm",
|
||||||
|
memInfo: memInfo{
|
||||||
|
TotalMemory: totalMemory,
|
||||||
|
FreeMemory: (totalMemory - usedMemory),
|
||||||
|
},
|
||||||
|
ID: fmt.Sprintf("%d", gpuID),
|
||||||
|
// Name: not exposed in sysfs directly, would require pci device id lookup
|
||||||
|
Major: int(major),
|
||||||
|
Minor: int(minor),
|
||||||
|
Patch: int(patch),
|
||||||
|
MinimumMemory: rocmMinimumMemory,
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||||
|
if len(visibleDevices) > 0 {
|
||||||
|
include := false
|
||||||
|
for _, visible := range visibleDevices {
|
||||||
|
if visible == gpuInfo.ID {
|
||||||
|
include = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !include {
|
||||||
|
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final validation is gfx compatibility - load the library if we haven't already loaded it
|
||||||
|
// even if the user overrides, we still need to validate the library
|
||||||
|
if libDir == "" {
|
||||||
|
libDir, err = AMDValidateLibDir()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
|
||||||
|
return []GpuInfo{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gpuInfo.DependencyPath = libDir
|
||||||
|
|
||||||
|
if gfxOverride == "" {
|
||||||
|
// Only load supported list once
|
||||||
|
if len(supported) == 0 {
|
||||||
|
supported, err = GetSupportedGFX(libDir)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
|
||||||
|
return []GpuInfo{}
|
||||||
|
}
|
||||||
|
slog.Debug("rocm supported GPUs", "types", supported)
|
||||||
|
}
|
||||||
|
gfx := fmt.Sprintf("gfx%d%d%d", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
|
||||||
|
if !slices.Contains[[]string, string](supported, gfx) {
|
||||||
|
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||||
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||||
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The GPU has passed all the verification steps and is supported
|
||||||
|
resp = append(resp, gpuInfo)
|
||||||
}
|
}
|
||||||
if resp.memInfo.DeviceCount > 0 {
|
if len(resp) == 0 {
|
||||||
resp.Library = "rocm"
|
slog.Info("no compatible amdgpu devices detected")
|
||||||
}
|
}
|
||||||
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quick check for AMD driver so we can skip amdgpu discovery if not present
|
// Quick check for AMD driver so we can skip amdgpu discovery if not present
|
||||||
|
@ -280,87 +297,24 @@ func AMDDetected() bool {
|
||||||
slog.Debug("amdgpu driver not detected " + sysfsDir)
|
slog.Debug("amdgpu driver not detected " + sysfsDir)
|
||||||
return false
|
return false
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
|
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupLink(source, target string) error {
|
|
||||||
if err := os.RemoveAll(target); err != nil {
|
|
||||||
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
|
|
||||||
}
|
|
||||||
if err := os.Symlink(source, target); err != nil {
|
|
||||||
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
|
|
||||||
}
|
|
||||||
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the AMD rocm lib dir is wired up
|
|
||||||
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
|
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
|
||||||
// failing that, tell the user how to download it on their own
|
// failing that, tell the user how to download it on their own
|
||||||
func AMDValidateLibDir() (string, error) {
|
func AMDValidateLibDir() (string, error) {
|
||||||
// We rely on the rpath compiled into our library to find rocm
|
libDir, err := commonAMDValidateLibDir()
|
||||||
// so we establish a symlink to wherever we find it on the system
|
|
||||||
// to <payloads>/rocm
|
|
||||||
payloadsDir, err := PayloadsDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we already have a rocm dependency wired, nothing more to do
|
|
||||||
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
|
|
||||||
if rocmLibUsable(rocmTargetDir) {
|
|
||||||
return rocmTargetDir, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// next to the running binary
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
peerDir := filepath.Dir(exe)
|
return libDir, nil
|
||||||
if rocmLibUsable(peerDir) {
|
|
||||||
slog.Debug("detected ROCM next to ollama executable " + peerDir)
|
|
||||||
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
|
|
||||||
}
|
|
||||||
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
|
|
||||||
if rocmLibUsable(peerDir) {
|
|
||||||
slog.Debug("detected ROCM next to ollama executable " + peerDir)
|
|
||||||
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Well known ollama installer path
|
// Well known ollama installer path
|
||||||
installedRocmDir := "/usr/share/ollama/lib/rocm"
|
installedRocmDir := "/usr/share/ollama/lib/rocm"
|
||||||
if rocmLibUsable(installedRocmDir) {
|
if rocmLibUsable(installedRocmDir) {
|
||||||
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
|
return installedRocmDir, nil
|
||||||
}
|
|
||||||
|
|
||||||
// Prefer explicit HIP env var
|
|
||||||
hipPath := os.Getenv("HIP_PATH")
|
|
||||||
if hipPath != "" {
|
|
||||||
hipLibDir := filepath.Join(hipPath, "lib")
|
|
||||||
if rocmLibUsable(hipLibDir) {
|
|
||||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
|
||||||
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan the library path for potential matches
|
|
||||||
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
|
|
||||||
for _, ldPath := range ldPaths {
|
|
||||||
d, err := filepath.Abs(ldPath)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if rocmLibUsable(d) {
|
|
||||||
return rocmTargetDir, setupLink(d, rocmTargetDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Well known location(s)
|
|
||||||
if rocmLibUsable("/opt/rocm/lib") {
|
|
||||||
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we still haven't found a usable rocm, the user will have to install it on their own
|
// If we still haven't found a usable rocm, the user will have to install it on their own
|
||||||
|
@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(string(verString)), nil
|
return strings.TrimSpace(string(verString)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AMDGFXVersions() map[int]Version {
|
|
||||||
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
|
|
||||||
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
|
||||||
res := map[int]Version{}
|
|
||||||
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
|
||||||
for _, match := range matches {
|
|
||||||
fp, err := os.Open(match)
|
|
||||||
if err != nil {
|
|
||||||
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
defer fp.Close()
|
|
||||||
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
|
|
||||||
if err != nil {
|
|
||||||
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == 0 {
|
|
||||||
// Skipping the CPU
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Align with HIP IDs (zero is first GPU, not CPU)
|
|
||||||
i -= 1
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(fp)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := strings.TrimSpace(scanner.Text())
|
|
||||||
if strings.HasPrefix(line, "gfx_target_version") {
|
|
||||||
ver := strings.Fields(line)
|
|
||||||
if len(ver) != 2 || len(ver[1]) < 5 {
|
|
||||||
if ver[1] != "0" {
|
|
||||||
slog.Debug("malformed " + line)
|
|
||||||
}
|
|
||||||
res[i] = Version{
|
|
||||||
Major: 0,
|
|
||||||
Minor: 0,
|
|
||||||
Patch: 0,
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
l := len(ver[1])
|
|
||||||
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
|
||||||
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
|
||||||
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
|
|
||||||
if err1 != nil || err2 != nil || err3 != nil {
|
|
||||||
slog.Debug("malformed int " + line)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
res[i] = Version{
|
|
||||||
Major: uint(major),
|
|
||||||
Minor: uint(minor),
|
|
||||||
Patch: uint(patch),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v Version) ToGFXString() string {
|
|
||||||
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,7 +7,10 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -22,36 +25,32 @@ var (
|
||||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
||||||
)
|
)
|
||||||
|
|
||||||
func AMDGetGPUInfo(resp *GpuInfo) {
|
func AMDGetGPUInfo() []GpuInfo {
|
||||||
|
resp := []GpuInfo{}
|
||||||
hl, err := NewHipLib()
|
hl, err := NewHipLib()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug(err.Error())
|
slog.Debug(err.Error())
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
defer hl.Release()
|
defer hl.Release()
|
||||||
skip := map[int]interface{}{}
|
|
||||||
ids := []int{}
|
|
||||||
resp.memInfo.DeviceCount = 0
|
|
||||||
resp.memInfo.TotalMemory = 0
|
|
||||||
resp.memInfo.FreeMemory = 0
|
|
||||||
|
|
||||||
ver, err := hl.AMDDriverVersion()
|
ver, err := hl.AMDDriverVersion()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
slog.Info("AMD Driver: " + ver)
|
slog.Info("AMD Driver: " + ver)
|
||||||
} else {
|
} else {
|
||||||
// For now this is benign, but we may eventually need to fail compatibility checks
|
// For now this is benign, but we may eventually need to fail compatibility checks
|
||||||
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
|
slog.Debug("error looking up amd driver version", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
|
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
|
||||||
count := hl.HipGetDeviceCount()
|
count := hl.HipGetDeviceCount()
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
libDir, err := AMDValidateLibDir()
|
libDir, err := AMDValidateLibDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
|
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var supported []string
|
var supported []string
|
||||||
|
@ -59,95 +58,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
supported, err = GetSupportedGFX(libDir)
|
supported, err = GetSupportedGFX(libDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
|
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info(fmt.Sprintf("detected %d hip devices", count))
|
slog.Info("detected hip devices", "count", count)
|
||||||
|
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
ids = append(ids, i)
|
|
||||||
err = hl.HipSetDevice(i)
|
err = hl.HipSetDevice(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
slog.Warn("set device", "id", i, "error", err)
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
props, err := hl.HipGetDeviceProperties(i)
|
props, err := hl.HipGetDeviceProperties(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
slog.Warn("get properties", "id", i, "error", err)
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
n := bytes.IndexByte(props.Name[:], 0)
|
n := bytes.IndexByte(props.Name[:], 0)
|
||||||
name := string(props.Name[:n])
|
name := string(props.Name[:n])
|
||||||
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
|
// TODO is UUID actually populated on windows?
|
||||||
|
// Can luid be used on windows for setting visible devices (and is it actually set?)
|
||||||
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
||||||
gfx := string(props.GcnArchName[:n])
|
gfx := string(props.GcnArchName[:n])
|
||||||
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
|
slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
|
||||||
|
var major, minor, patch string
|
||||||
|
switch len(gfx) {
|
||||||
|
case 6:
|
||||||
|
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
|
||||||
|
case 7:
|
||||||
|
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
|
||||||
|
}
|
||||||
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
||||||
// TODO Why isn't props.iGPU accurate!?
|
// TODO Why isn't props.iGPU accurate!?
|
||||||
if strings.EqualFold(name, iGPUName) {
|
if strings.EqualFold(name, iGPUName) {
|
||||||
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
|
slog.Info("iGPU detected skipping", "id", i)
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
if !slices.Contains[[]string, string](supported, gfx) {
|
if !slices.Contains[[]string, string](supported, gfx) {
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
|
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
|
slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
totalMemory, freeMemory, err := hl.HipMemGetInfo()
|
freeMemory, totalMemory, err := hl.HipMemGetInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
slog.Warn("get mem info", "id", i, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO according to docs, freeMem may lie on windows!
|
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||||
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
|
if totalMemory < IGPUMemLimit {
|
||||||
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory))
|
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||||
resp.memInfo.DeviceCount++
|
continue
|
||||||
resp.memInfo.TotalMemory += totalMemory
|
}
|
||||||
resp.memInfo.FreeMemory += freeMemory
|
|
||||||
|
// TODO revisit this once ROCm v6 is available on windows.
|
||||||
|
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
|
||||||
|
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||||
|
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
|
||||||
|
gpuInfo := GpuInfo{
|
||||||
|
Library: "rocm",
|
||||||
|
memInfo: memInfo{
|
||||||
|
TotalMemory: totalMemory,
|
||||||
|
FreeMemory: freeMemory,
|
||||||
|
},
|
||||||
|
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
|
||||||
|
DependencyPath: libDir,
|
||||||
|
MinimumMemory: rocmMinimumMemory,
|
||||||
|
}
|
||||||
|
if major != "" {
|
||||||
|
gpuInfo.Major, err = strconv.Atoi(major)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if minor != "" {
|
||||||
|
gpuInfo.Minor, err = strconv.Atoi(minor)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if patch != "" {
|
||||||
|
gpuInfo.Patch, err = strconv.Atoi(patch)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if gpuInfo.Major < RocmComputeMin {
|
||||||
|
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%d", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = append(resp, gpuInfo)
|
||||||
}
|
}
|
||||||
if resp.memInfo.DeviceCount > 0 {
|
|
||||||
resp.Library = "rocm"
|
return resp
|
||||||
}
|
|
||||||
// Abort if all GPUs are skipped
|
|
||||||
if len(skip) >= count {
|
|
||||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(skip) > 0 {
|
|
||||||
amdSetVisibleDevices(ids, skip)
|
|
||||||
}
|
|
||||||
UpdatePath(libDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AMDValidateLibDir() (string, error) {
|
func AMDValidateLibDir() (string, error) {
|
||||||
// On windows non-admins typically can't create links
|
libDir, err := commonAMDValidateLibDir()
|
||||||
// so instead of trying to rely on rpath and a link in
|
|
||||||
// $LibDir/rocm, we instead rely on setting PATH to point
|
|
||||||
// to the location of the ROCm library
|
|
||||||
|
|
||||||
// Installer payload location if we're running the installed binary
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
return libDir, nil
|
||||||
if rocmLibUsable(rocmTargetDir) {
|
|
||||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
|
||||||
return rocmTargetDir, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Installer payload (if we're running from some other location)
|
// Installer payload (if we're running from some other location)
|
||||||
|
@ -159,21 +180,6 @@ func AMDValidateLibDir() (string, error) {
|
||||||
return rocmTargetDir, nil
|
return rocmTargetDir, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefer explicit HIP env var
|
|
||||||
hipPath := os.Getenv("HIP_PATH")
|
|
||||||
if hipPath != "" {
|
|
||||||
hipLibDir := filepath.Join(hipPath, "bin")
|
|
||||||
if rocmLibUsable(hipLibDir) {
|
|
||||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
|
||||||
return hipLibDir, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Well known location(s)
|
|
||||||
if rocmLibUsable(RocmStandardLocation) {
|
|
||||||
return RocmStandardLocation, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
||||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
||||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||||
|
|
|
@ -80,7 +80,7 @@ func cleanupTmpDirs() {
|
||||||
}
|
}
|
||||||
err = os.RemoveAll(d)
|
err = os.RemoveAll(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
|
slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -120,7 +120,7 @@ func UpdatePath(dir string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
|
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
|
||||||
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
|
slog.Info("updating", "PATH", newPath)
|
||||||
os.Setenv("PATH", newPath)
|
os.Setenv("PATH", newPath)
|
||||||
}
|
}
|
||||||
// linux and darwin rely on rpath
|
// linux and darwin rely on rpath
|
||||||
|
|
22
gpu/cuda_common.go
Normal file
22
gpu/cuda_common.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
//go:build linux || windows
|
||||||
|
|
||||||
|
package gpu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||||
|
ids := []string{}
|
||||||
|
for _, info := range gpuInfo {
|
||||||
|
if info.Library != "cuda" {
|
||||||
|
// TODO shouldn't happen if things are wired correctly...
|
||||||
|
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ids = append(ids, info.ID)
|
||||||
|
}
|
||||||
|
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||||
|
|
||||||
|
}
|
233
gpu/gpu.go
233
gpu/gpu.go
|
@ -16,7 +16,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
@ -25,8 +24,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type handles struct {
|
type handles struct {
|
||||||
nvml *C.nvml_handle_t
|
deviceCount int
|
||||||
cudart *C.cudart_handle_t
|
cudart *C.cudart_handle_t
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -39,26 +38,10 @@ var gpuMutex sync.Mutex
|
||||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||||
var CudaComputeMin = [2]C.int{5, 0}
|
var CudaComputeMin = [2]C.int{5, 0}
|
||||||
|
|
||||||
// Possible locations for the nvidia-ml library
|
var RocmComputeMin = 9
|
||||||
var NvmlLinuxGlobs = []string{
|
|
||||||
"/usr/local/cuda/lib64/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/wsl/lib/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
|
|
||||||
"/opt/cuda/lib64/libnvidia-ml.so*",
|
|
||||||
"/usr/lib*/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
|
|
||||||
"/usr/local/lib*/libnvidia-ml.so*",
|
|
||||||
|
|
||||||
// TODO: are these stubs ever valid?
|
// TODO find a better way to detect iGPU instead of minimum memory
|
||||||
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
|
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
|
||||||
}
|
|
||||||
|
|
||||||
var NvmlWindowsGlobs = []string{
|
|
||||||
"c:\\Windows\\System32\\nvml.dll",
|
|
||||||
}
|
|
||||||
|
|
||||||
var CudartLinuxGlobs = []string{
|
var CudartLinuxGlobs = []string{
|
||||||
"/usr/local/cuda/lib64/libcudart.so*",
|
"/usr/local/cuda/lib64/libcudart.so*",
|
||||||
|
@ -88,26 +71,18 @@ func initGPUHandles() *handles {
|
||||||
|
|
||||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||||
|
|
||||||
gpuHandles := &handles{nil, nil}
|
gpuHandles := &handles{}
|
||||||
var nvmlMgmtName string
|
|
||||||
var nvmlMgmtPatterns []string
|
|
||||||
var cudartMgmtName string
|
var cudartMgmtName string
|
||||||
var cudartMgmtPatterns []string
|
var cudartMgmtPatterns []string
|
||||||
|
|
||||||
tmpDir, _ := PayloadsDir()
|
tmpDir, _ := PayloadsDir()
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
nvmlMgmtName = "nvml.dll"
|
|
||||||
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
|
|
||||||
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
|
|
||||||
cudartMgmtName = "cudart64_*.dll"
|
cudartMgmtName = "cudart64_*.dll"
|
||||||
localAppData := os.Getenv("LOCALAPPDATA")
|
localAppData := os.Getenv("LOCALAPPDATA")
|
||||||
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
|
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
|
||||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
|
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
|
||||||
case "linux":
|
case "linux":
|
||||||
nvmlMgmtName = "libnvidia-ml.so"
|
|
||||||
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
|
|
||||||
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
|
|
||||||
cudartMgmtName = "libcudart.so*"
|
cudartMgmtName = "libcudart.so*"
|
||||||
if tmpDir != "" {
|
if tmpDir != "" {
|
||||||
// TODO - add "payloads" for subprocess
|
// TODO - add "payloads" for subprocess
|
||||||
|
@ -118,31 +93,21 @@ func initGPUHandles() *handles {
|
||||||
return gpuHandles
|
return gpuHandles
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Detecting GPU type")
|
slog.Info("Detecting GPUs")
|
||||||
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
|
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
|
||||||
if len(cudartLibPaths) > 0 {
|
if len(cudartLibPaths) > 0 {
|
||||||
cudart := LoadCUDARTMgmt(cudartLibPaths)
|
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
|
||||||
if cudart != nil {
|
if cudart != nil {
|
||||||
slog.Info("Nvidia GPU detected via cudart")
|
slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
|
||||||
gpuHandles.cudart = cudart
|
gpuHandles.cudart = cudart
|
||||||
return gpuHandles
|
gpuHandles.deviceCount = deviceCount
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
|
|
||||||
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
|
|
||||||
if len(nvmlLibPaths) > 0 {
|
|
||||||
nvml := LoadNVMLMgmt(nvmlLibPaths)
|
|
||||||
if nvml != nil {
|
|
||||||
slog.Info("Nvidia GPU detected via nvidia-ml")
|
|
||||||
gpuHandles.nvml = nvml
|
|
||||||
return gpuHandles
|
return gpuHandles
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return gpuHandles
|
return gpuHandles
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGPUInfo() GpuInfo {
|
func GetGPUInfo() GpuInfoList {
|
||||||
// TODO - consider exploring lspci (and equivalent on windows) to check for
|
// TODO - consider exploring lspci (and equivalent on windows) to check for
|
||||||
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
||||||
gpuMutex.Lock()
|
gpuMutex.Lock()
|
||||||
|
@ -150,9 +115,6 @@ func GetGPUInfo() GpuInfo {
|
||||||
|
|
||||||
gpuHandles := initGPUHandles()
|
gpuHandles := initGPUHandles()
|
||||||
defer func() {
|
defer func() {
|
||||||
if gpuHandles.nvml != nil {
|
|
||||||
C.nvml_release(*gpuHandles.nvml)
|
|
||||||
}
|
|
||||||
if gpuHandles.cudart != nil {
|
if gpuHandles.cudart != nil {
|
||||||
C.cudart_release(*gpuHandles.cudart)
|
C.cudart_release(*gpuHandles.cudart)
|
||||||
}
|
}
|
||||||
|
@ -165,72 +127,63 @@ func GetGPUInfo() GpuInfo {
|
||||||
}
|
}
|
||||||
|
|
||||||
var memInfo C.mem_info_t
|
var memInfo C.mem_info_t
|
||||||
resp := GpuInfo{}
|
resp := []GpuInfo{}
|
||||||
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
|
||||||
C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
|
// NVIDIA first
|
||||||
|
for i := 0; i < gpuHandles.deviceCount; i++ {
|
||||||
|
// TODO once we support CPU compilation variants of GPU libraries refine this...
|
||||||
|
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
gpuInfo := GpuInfo{
|
||||||
|
Library: "cuda",
|
||||||
|
}
|
||||||
|
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
|
||||||
if memInfo.err != nil {
|
if memInfo.err != nil {
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
|
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||||
C.free(unsafe.Pointer(memInfo.err))
|
C.free(unsafe.Pointer(memInfo.err))
|
||||||
} else if memInfo.count > 0 {
|
continue
|
||||||
// Verify minimum compute capability
|
|
||||||
var cc C.nvml_compute_capability_t
|
|
||||||
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
|
|
||||||
if cc.err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
|
|
||||||
C.free(unsafe.Pointer(cc.err))
|
|
||||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
|
||||||
resp.Library = "cuda"
|
|
||||||
resp.MinimumMemory = cudaMinimumMemory
|
|
||||||
} else {
|
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
|
||||||
C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
|
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
|
||||||
if memInfo.err != nil {
|
continue
|
||||||
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
|
|
||||||
C.free(unsafe.Pointer(memInfo.err))
|
|
||||||
} else if memInfo.count > 0 {
|
|
||||||
// Verify minimum compute capability
|
|
||||||
var cc C.cudart_compute_capability_t
|
|
||||||
C.cudart_compute_capability(*gpuHandles.cudart, &cc)
|
|
||||||
if cc.err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
|
|
||||||
C.free(unsafe.Pointer(cc.err))
|
|
||||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
|
||||||
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
|
||||||
resp.Library = "cuda"
|
|
||||||
resp.MinimumMemory = cudaMinimumMemory
|
|
||||||
} else {
|
|
||||||
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||||
AMDGetGPUInfo(&resp)
|
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||||
if resp.Library != "" {
|
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||||
resp.MinimumMemory = rocmMinimumMemory
|
gpuInfo.Major = int(memInfo.major)
|
||||||
return resp
|
gpuInfo.Minor = int(memInfo.minor)
|
||||||
}
|
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||||
}
|
|
||||||
if resp.Library == "" {
|
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||||
C.cpu_check_ram(&memInfo)
|
resp = append(resp, gpuInfo)
|
||||||
resp.Library = "cpu"
|
}
|
||||||
resp.Variant = cpuVariant
|
|
||||||
}
|
// Then AMD
|
||||||
if memInfo.err != nil {
|
resp = append(resp, AMDGetGPUInfo()...)
|
||||||
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
|
|
||||||
C.free(unsafe.Pointer(memInfo.err))
|
if len(resp) == 0 {
|
||||||
return resp
|
C.cpu_check_ram(&memInfo)
|
||||||
|
if memInfo.err != nil {
|
||||||
|
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
|
||||||
|
C.free(unsafe.Pointer(memInfo.err))
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
gpuInfo := GpuInfo{
|
||||||
|
Library: "cpu",
|
||||||
|
Variant: cpuVariant,
|
||||||
|
}
|
||||||
|
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||||
|
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||||
|
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||||
|
|
||||||
|
resp = append(resp, gpuInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.DeviceCount = uint32(memInfo.count)
|
|
||||||
resp.FreeMemory = uint64(memInfo.free)
|
|
||||||
resp.TotalMemory = uint64(memInfo.total)
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCPUMem() (memInfo, error) {
|
func GetCPUMem() (memInfo, error) {
|
||||||
var ret memInfo
|
var ret memInfo
|
||||||
var info C.mem_info_t
|
var info C.mem_info_t
|
||||||
C.cpu_check_ram(&info)
|
C.cpu_check_ram(&info)
|
||||||
|
@ -243,29 +196,11 @@ func getCPUMem() (memInfo, error) {
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckVRAM() (uint64, error) {
|
|
||||||
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
|
|
||||||
if userLimit != "" {
|
|
||||||
avail, err := strconv.ParseInt(userLimit, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
|
|
||||||
}
|
|
||||||
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
|
|
||||||
return uint64(avail), nil
|
|
||||||
}
|
|
||||||
gpuInfo := GetGPUInfo()
|
|
||||||
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
|
|
||||||
return gpuInfo.FreeMemory, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
|
|
||||||
}
|
|
||||||
|
|
||||||
func FindGPULibs(baseLibName string, patterns []string) []string {
|
func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||||
var ldPaths []string
|
var ldPaths []string
|
||||||
gpuLibPaths := []string{}
|
gpuLibPaths := []string{}
|
||||||
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
|
slog.Debug("Searching for GPU library", "name", baseLibName)
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
|
@ -283,7 +218,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||||
}
|
}
|
||||||
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
|
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
|
||||||
}
|
}
|
||||||
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
|
slog.Debug("gpu library search", "globs", patterns)
|
||||||
for _, pattern := range patterns {
|
for _, pattern := range patterns {
|
||||||
// Ignore glob discovery errors
|
// Ignore glob discovery errors
|
||||||
matches, _ := filepath.Glob(pattern)
|
matches, _ := filepath.Glob(pattern)
|
||||||
|
@ -311,28 +246,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
|
slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
|
||||||
return gpuLibPaths
|
return gpuLibPaths
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
|
func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
|
||||||
var resp C.nvml_init_resp_t
|
|
||||||
resp.ch.verbose = getVerboseState()
|
|
||||||
for _, libPath := range nvmlLibPaths {
|
|
||||||
lib := C.CString(libPath)
|
|
||||||
defer C.free(unsafe.Pointer(lib))
|
|
||||||
C.nvml_init(lib, &resp)
|
|
||||||
if resp.err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
|
|
||||||
C.free(unsafe.Pointer(resp.err))
|
|
||||||
} else {
|
|
||||||
return &resp.ch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
|
|
||||||
var resp C.cudart_init_resp_t
|
var resp C.cudart_init_resp_t
|
||||||
resp.ch.verbose = getVerboseState()
|
resp.ch.verbose = getVerboseState()
|
||||||
for _, libPath := range cudartLibPaths {
|
for _, libPath := range cudartLibPaths {
|
||||||
|
@ -340,13 +258,13 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
|
||||||
defer C.free(unsafe.Pointer(lib))
|
defer C.free(unsafe.Pointer(lib))
|
||||||
C.cudart_init(lib, &resp)
|
C.cudart_init(lib, &resp)
|
||||||
if resp.err != nil {
|
if resp.err != nil {
|
||||||
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
|
slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
|
||||||
C.free(unsafe.Pointer(resp.err))
|
C.free(unsafe.Pointer(resp.err))
|
||||||
} else {
|
} else {
|
||||||
return &resp.ch
|
return int(resp.num_devices), &resp.ch, libPath
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return 0, nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func getVerboseState() C.uint16_t {
|
func getVerboseState() C.uint16_t {
|
||||||
|
@ -355,3 +273,22 @@ func getVerboseState() C.uint16_t {
|
||||||
}
|
}
|
||||||
return C.uint16_t(0)
|
return C.uint16_t(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Given the list of GPUs this instantiation is targeted for,
|
||||||
|
// figure out the visible devices environment variable
|
||||||
|
//
|
||||||
|
// If different libraries are detected, the first one is what we use
|
||||||
|
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||||
|
if len(l) == 0 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
switch l[0].Library {
|
||||||
|
case "cuda":
|
||||||
|
return cudaGetVisibleDevicesEnv(l)
|
||||||
|
case "rocm":
|
||||||
|
return rocmGetVisibleDevicesEnv(l)
|
||||||
|
default:
|
||||||
|
slog.Debug("no filter required for library " + l[0].Library)
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -9,52 +9,41 @@ package gpu
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
|
func GetGPUInfo() GpuInfoList {
|
||||||
func CheckVRAM() (uint64, error) {
|
mem, _ := GetCPUMem()
|
||||||
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
|
|
||||||
if userLimit != "" {
|
|
||||||
avail, err := strconv.ParseInt(userLimit, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
|
|
||||||
}
|
|
||||||
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
|
|
||||||
return uint64(avail), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOARCH == "amd64" {
|
if runtime.GOARCH == "amd64" {
|
||||||
// gpu not supported, this may not be metal
|
return []GpuInfo{
|
||||||
return 0, nil
|
{
|
||||||
}
|
Library: "cpu",
|
||||||
|
Variant: GetCPUVariant(),
|
||||||
return uint64(C.getRecommendedMaxVRAM()), nil
|
memInfo: mem,
|
||||||
}
|
},
|
||||||
|
|
||||||
func GetGPUInfo() GpuInfo {
|
|
||||||
mem, _ := getCPUMem()
|
|
||||||
if runtime.GOARCH == "amd64" {
|
|
||||||
return GpuInfo{
|
|
||||||
Library: "cpu",
|
|
||||||
Variant: GetCPUVariant(),
|
|
||||||
memInfo: mem,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return GpuInfo{
|
info := GpuInfo{
|
||||||
Library: "metal",
|
Library: "metal",
|
||||||
memInfo: mem,
|
ID: "0",
|
||||||
}
|
}
|
||||||
|
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
|
||||||
|
|
||||||
|
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
|
||||||
|
info.FreeMemory = info.TotalMemory
|
||||||
|
|
||||||
|
info.MinimumMemory = 0
|
||||||
|
return []GpuInfo{info}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCPUMem() (memInfo, error) {
|
func GetCPUMem() (memInfo, error) {
|
||||||
return memInfo{
|
return memInfo{
|
||||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||||
FreeMemory: 0,
|
FreeMemory: 0,
|
||||||
DeviceCount: 1,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||||
|
// No-op on darwin
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
|
@ -38,12 +38,17 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define GPU_ID_LEN 64
|
||||||
|
|
||||||
typedef struct mem_info {
|
typedef struct mem_info {
|
||||||
|
char *err; // If non-nill, caller responsible for freeing
|
||||||
|
char gpu_id[GPU_ID_LEN];
|
||||||
uint64_t total;
|
uint64_t total;
|
||||||
uint64_t free;
|
uint64_t free;
|
||||||
unsigned int count;
|
|
||||||
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
|
// Compute Capability
|
||||||
char *err; // If non-nill, caller responsible for freeing
|
int major;
|
||||||
|
int minor;
|
||||||
} mem_info_t;
|
} mem_info_t;
|
||||||
|
|
||||||
void cpu_check_ram(mem_info_t *resp);
|
void cpu_check_ram(mem_info_t *resp);
|
||||||
|
@ -52,7 +57,6 @@ void cpu_check_ram(mem_info_t *resp);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "gpu_info_nvml.h"
|
|
||||||
#include "gpu_info_cudart.h"
|
#include "gpu_info_cudart.h"
|
||||||
|
|
||||||
#endif // __GPU_INFO_H__
|
#endif // __GPU_INFO_H__
|
||||||
|
|
|
@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||||
MEMORYSTATUSEX info;
|
MEMORYSTATUSEX info;
|
||||||
info.dwLength = sizeof(info);
|
info.dwLength = sizeof(info);
|
||||||
if (GlobalMemoryStatusEx(&info) != 0) {
|
if (GlobalMemoryStatusEx(&info) != 0) {
|
||||||
resp->count = 1;
|
|
||||||
resp->total = info.ullTotalPhys;
|
resp->total = info.ullTotalPhys;
|
||||||
resp->free = info.ullAvailPhys;
|
resp->free = info.ullAvailPhys;
|
||||||
|
resp->major = 0;
|
||||||
|
resp->minor = 0;
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||||
} else {
|
} else {
|
||||||
resp->err = LOAD_ERR();
|
resp->err = LOAD_ERR();
|
||||||
}
|
}
|
||||||
|
@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||||
if (sysinfo(&info) != 0) {
|
if (sysinfo(&info) != 0) {
|
||||||
resp->err = strdup(strerror(errno));
|
resp->err = strdup(strerror(errno));
|
||||||
} else {
|
} else {
|
||||||
resp->count = 1;
|
|
||||||
resp->total = info.totalram * info.mem_unit;
|
resp->total = info.totalram * info.mem_unit;
|
||||||
resp->free = info.freeram * info.mem_unit;
|
resp->free = info.freeram * info.mem_unit;
|
||||||
|
resp->major = 0;
|
||||||
|
resp->minor = 0;
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
cudartReturn_t ret;
|
cudartReturn_t ret;
|
||||||
resp->err = NULL;
|
resp->err = NULL;
|
||||||
|
resp->num_devices = 0;
|
||||||
const int buflen = 256;
|
const int buflen = 256;
|
||||||
char buf[buflen + 1];
|
char buf[buflen + 1];
|
||||||
int i;
|
int i;
|
||||||
|
@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
|
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
|
||||||
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
|
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
|
||||||
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
|
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
|
||||||
|
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
|
||||||
{NULL, NULL},
|
{NULL, NULL},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO once we've squashed the remaining corner cases remove this log
|
|
||||||
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
|
|
||||||
|
|
||||||
for (i = 0; l[i].s != NULL; i++) {
|
for (i = 0; l[i].s != NULL; i++) {
|
||||||
// TODO once we've squashed the remaining corner cases remove this log
|
|
||||||
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
|
||||||
|
|
||||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||||
if (!l[i].p) {
|
if (!l[i].p) {
|
||||||
char *msg = LOAD_ERR();
|
char *msg = LOAD_ERR();
|
||||||
|
@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
UNLOAD_LIBRARY(resp->ch.handle);
|
UNLOAD_LIBRARY(resp->ch.handle);
|
||||||
resp->ch.handle = NULL;
|
resp->ch.handle = NULL;
|
||||||
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
||||||
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
|
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
snprintf(buf, buflen, "cudart init failure: %d", ret);
|
snprintf(buf, buflen, "cudart init failure: %d", ret);
|
||||||
|
@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
||||||
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
|
||||||
|
if (ret != CUDART_SUCCESS) {
|
||||||
|
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
|
||||||
|
UNLOAD_LIBRARY(resp->ch.handle);
|
||||||
|
resp->ch.handle = NULL;
|
||||||
|
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||||
|
resp->err = strdup(buf);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
|
void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||||
resp->err = NULL;
|
resp->err = NULL;
|
||||||
cudartMemory_t memInfo = {0,0,0};
|
cudartMemory_t memInfo = {0,0,0};
|
||||||
cudartReturn_t ret;
|
cudartReturn_t ret;
|
||||||
const int buflen = 256;
|
const int buflen = 256;
|
||||||
char buf[buflen + 1];
|
char buf[buflen + 1];
|
||||||
int i;
|
|
||||||
|
|
||||||
if (h.handle == NULL) {
|
if (h.handle == NULL) {
|
||||||
resp->err = strdup("cudart handle isn't initialized");
|
resp->err = strdup("cudart handle isn't initialized");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// cudaGetDeviceCount takes int type, resp-> count is uint
|
ret = (*h.cudaSetDevice)(i);
|
||||||
int deviceCount;
|
|
||||||
ret = (*h.cudaGetDeviceCount)(&deviceCount);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
if (ret != CUDART_SUCCESS) {
|
||||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
snprintf(buf, buflen, "cudart device failed to initialize");
|
||||||
resp->err = strdup(buf);
|
resp->err = strdup(buf);
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cudaDeviceProp_t props;
|
||||||
|
ret = (*h.cudaGetDeviceProperties)(&props, i);
|
||||||
|
if (ret != CUDART_SUCCESS) {
|
||||||
|
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||||
|
resp->major = 0;
|
||||||
|
resp->minor = 0;
|
||||||
} else {
|
} else {
|
||||||
resp->count = (unsigned int)deviceCount;
|
int allNull = 1;
|
||||||
}
|
for (int j = 0; j < 16; j++) {
|
||||||
|
if (props.uuid.bytes[j] != 0) {
|
||||||
resp->total = 0;
|
allNull = 0;
|
||||||
resp->free = 0;
|
break;
|
||||||
for (i = 0; i < resp-> count; i++) {
|
}
|
||||||
ret = (*h.cudaSetDevice)(i);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
|
if (allNull != 0) {
|
||||||
if (ret != CUDART_SUCCESS) {
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||||
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
|
} else {
|
||||||
resp->err = strdup(buf);
|
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
|
||||||
return;
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
|
||||||
|
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||||
|
props.uuid.bytes[0],
|
||||||
|
props.uuid.bytes[1],
|
||||||
|
props.uuid.bytes[2],
|
||||||
|
props.uuid.bytes[3],
|
||||||
|
props.uuid.bytes[4],
|
||||||
|
props.uuid.bytes[5],
|
||||||
|
props.uuid.bytes[6],
|
||||||
|
props.uuid.bytes[7],
|
||||||
|
props.uuid.bytes[8],
|
||||||
|
props.uuid.bytes[9],
|
||||||
|
props.uuid.bytes[10],
|
||||||
|
props.uuid.bytes[11],
|
||||||
|
props.uuid.bytes[12],
|
||||||
|
props.uuid.bytes[13],
|
||||||
|
props.uuid.bytes[14],
|
||||||
|
props.uuid.bytes[15]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
resp->major = props.major;
|
||||||
|
resp->minor = props.minor;
|
||||||
|
|
||||||
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
|
// TODO add other useful properties from props
|
||||||
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
|
|
||||||
|
|
||||||
resp->total += memInfo.total;
|
|
||||||
resp->free += memInfo.free;
|
|
||||||
}
|
}
|
||||||
}
|
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
|
||||||
|
|
||||||
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
|
|
||||||
resp->err = NULL;
|
|
||||||
resp->major = 0;
|
|
||||||
resp->minor = 0;
|
|
||||||
int major = 0;
|
|
||||||
int minor = 0;
|
|
||||||
cudartReturn_t ret;
|
|
||||||
const int buflen = 256;
|
|
||||||
char buf[buflen + 1];
|
|
||||||
int i;
|
|
||||||
|
|
||||||
if (h.handle == NULL) {
|
|
||||||
resp->err = strdup("cudart handle not initialized");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int devices;
|
|
||||||
ret = (*h.cudaGetDeviceCount)(&devices);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
if (ret != CUDART_SUCCESS) {
|
||||||
snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
|
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
|
||||||
resp->err = strdup(buf);
|
resp->err = strdup(buf);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (i = 0; i < devices; i++) {
|
resp->total = memInfo.total;
|
||||||
ret = (*h.cudaSetDevice)(i);
|
resp->free = memInfo.free;
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
|
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
|
||||||
if (ret != CUDART_SUCCESS) {
|
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
|
||||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Report the lowest major.minor we detect as that limits our compatibility
|
|
||||||
if (resp->major == 0 || resp->major > major ) {
|
|
||||||
resp->major = major;
|
|
||||||
resp->minor = minor;
|
|
||||||
} else if ( resp->major == major && resp->minor > minor ) {
|
|
||||||
resp->minor = minor;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void cudart_release(cudart_handle_t h) {
|
void cudart_release(cudart_handle_t h) {
|
||||||
|
|
|
@ -6,7 +6,8 @@
|
||||||
// Just enough typedef's to dlopen/dlsym for memory information
|
// Just enough typedef's to dlopen/dlsym for memory information
|
||||||
typedef enum cudartReturn_enum {
|
typedef enum cudartReturn_enum {
|
||||||
CUDART_SUCCESS = 0,
|
CUDART_SUCCESS = 0,
|
||||||
CUDART_UNSUPPORTED = 1,
|
CUDA_ERROR_INVALID_VALUE = 1,
|
||||||
|
CUDA_ERROR_MEMORY_ALLOCATION = 2,
|
||||||
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||||
// Other values omitted for now...
|
// Other values omitted for now...
|
||||||
} cudartReturn_t;
|
} cudartReturn_t;
|
||||||
|
@ -14,6 +15,11 @@ typedef enum cudartReturn_enum {
|
||||||
typedef enum cudartDeviceAttr_enum {
|
typedef enum cudartDeviceAttr_enum {
|
||||||
cudartDevAttrComputeCapabilityMajor = 75,
|
cudartDevAttrComputeCapabilityMajor = 75,
|
||||||
cudartDevAttrComputeCapabilityMinor = 76,
|
cudartDevAttrComputeCapabilityMinor = 76,
|
||||||
|
|
||||||
|
// TODO - not yet wired up but may be useful for Jetson or other
|
||||||
|
// integrated GPU scenarios with shared memory
|
||||||
|
cudaDevAttrIntegrated = 18
|
||||||
|
|
||||||
} cudartDeviceAttr_t;
|
} cudartDeviceAttr_t;
|
||||||
|
|
||||||
typedef void *cudartDevice_t; // Opaque is sufficient
|
typedef void *cudartDevice_t; // Opaque is sufficient
|
||||||
|
@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
|
||||||
int minor;
|
int minor;
|
||||||
} cudartDriverVersion_t;
|
} cudartDriverVersion_t;
|
||||||
|
|
||||||
|
typedef struct cudaUUID {
|
||||||
|
unsigned char bytes[16];
|
||||||
|
} cudaUUID_t;
|
||||||
|
typedef struct cudaDeviceProp {
|
||||||
|
char name[256]; /**< ASCII string identifying device */
|
||||||
|
cudaUUID_t uuid; /**< 16-byte unique identifier */
|
||||||
|
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
|
||||||
|
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
|
||||||
|
size_t totalGlobalMem; /**< Global memory available on device in bytes */
|
||||||
|
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
|
||||||
|
int regsPerBlock; /**< 32-bit registers available per block */
|
||||||
|
int warpSize; /**< Warp size in threads */
|
||||||
|
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
|
||||||
|
int maxThreadsPerBlock; /**< Maximum number of threads per block */
|
||||||
|
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
|
||||||
|
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
|
||||||
|
int clockRate; /**< Clock frequency in kilohertz */
|
||||||
|
size_t totalConstMem; /**< Constant memory available on device in bytes */
|
||||||
|
int major; /**< Major compute capability */
|
||||||
|
int minor; /**< Minor compute capability */
|
||||||
|
size_t textureAlignment; /**< Alignment requirement for textures */
|
||||||
|
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
|
||||||
|
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
|
||||||
|
int multiProcessorCount; /**< Number of multiprocessors on device */
|
||||||
|
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
|
||||||
|
int integrated; /**< Device is integrated as opposed to discrete */
|
||||||
|
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
|
||||||
|
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
|
||||||
|
int maxTexture1D; /**< Maximum 1D texture size */
|
||||||
|
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
|
||||||
|
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
|
||||||
|
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
|
||||||
|
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
|
||||||
|
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
|
||||||
|
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
|
||||||
|
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
|
||||||
|
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
|
||||||
|
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
|
||||||
|
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
|
||||||
|
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
|
||||||
|
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
|
||||||
|
int maxSurface1D; /**< Maximum 1D surface size */
|
||||||
|
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
|
||||||
|
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
|
||||||
|
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
|
||||||
|
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
|
||||||
|
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
|
||||||
|
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
|
||||||
|
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
|
||||||
|
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
|
||||||
|
int ECCEnabled; /**< Device has ECC support enabled */
|
||||||
|
int pciBusID; /**< PCI bus ID of the device */
|
||||||
|
int pciDeviceID; /**< PCI device ID of the device */
|
||||||
|
int pciDomainID; /**< PCI domain ID of the device */
|
||||||
|
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
|
||||||
|
int asyncEngineCount; /**< Number of asynchronous engines */
|
||||||
|
int unifiedAddressing; /**< Device shares a unified address space with the host */
|
||||||
|
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
|
||||||
|
int memoryBusWidth; /**< Global memory bus width in bits */
|
||||||
|
int l2CacheSize; /**< Size of L2 cache in bytes */
|
||||||
|
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
|
||||||
|
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
|
||||||
|
int streamPrioritiesSupported; /**< Device supports stream priorities */
|
||||||
|
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
|
||||||
|
int localL1CacheSupported; /**< Device supports caching locals in L1 */
|
||||||
|
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
|
||||||
|
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
|
||||||
|
int managedMemory; /**< Device supports allocating managed memory on this system */
|
||||||
|
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
|
||||||
|
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
|
||||||
|
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
|
||||||
|
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
|
||||||
|
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
|
||||||
|
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
|
||||||
|
int computePreemptionSupported; /**< Device supports Compute Preemption */
|
||||||
|
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
|
||||||
|
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
|
||||||
|
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
|
||||||
|
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
|
||||||
|
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
|
||||||
|
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
|
||||||
|
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
|
||||||
|
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
|
||||||
|
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
|
||||||
|
} cudaDeviceProp_t;
|
||||||
|
|
||||||
typedef struct cudart_handle {
|
typedef struct cudart_handle {
|
||||||
void *handle;
|
void *handle;
|
||||||
uint16_t verbose;
|
uint16_t verbose;
|
||||||
|
@ -38,23 +130,17 @@ typedef struct cudart_handle {
|
||||||
cudartReturn_t (*cudaGetDeviceCount)(int *);
|
cudartReturn_t (*cudaGetDeviceCount)(int *);
|
||||||
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
|
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
|
||||||
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
|
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
|
||||||
|
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
|
||||||
} cudart_handle_t;
|
} cudart_handle_t;
|
||||||
|
|
||||||
typedef struct cudart_init_resp {
|
typedef struct cudart_init_resp {
|
||||||
char *err; // If err is non-null handle is invalid
|
char *err; // If err is non-null handle is invalid
|
||||||
cudart_handle_t ch;
|
cudart_handle_t ch;
|
||||||
|
int num_devices;
|
||||||
} cudart_init_resp_t;
|
} cudart_init_resp_t;
|
||||||
|
|
||||||
typedef struct cudart_compute_capability {
|
|
||||||
char *err;
|
|
||||||
int major;
|
|
||||||
int minor;
|
|
||||||
} cudart_compute_capability_t;
|
|
||||||
|
|
||||||
|
|
||||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
||||||
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
|
void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
|
||||||
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
|
|
||||||
void cudart_release(cudart_handle_t ch);
|
void cudart_release(cudart_handle_t ch);
|
||||||
|
|
||||||
#endif // __GPU_INFO_CUDART_H__
|
#endif // __GPU_INFO_CUDART_H__
|
||||||
|
|
|
@ -1,221 +0,0 @@
|
||||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
|
||||||
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
#include "gpu_info_nvml.h"
|
|
||||||
|
|
||||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
|
|
||||||
nvmlReturn_t ret;
|
|
||||||
resp->err = NULL;
|
|
||||||
const int buflen = 256;
|
|
||||||
char buf[buflen + 1];
|
|
||||||
int i;
|
|
||||||
|
|
||||||
struct lookup {
|
|
||||||
char *s;
|
|
||||||
void **p;
|
|
||||||
} l[] = {
|
|
||||||
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
|
|
||||||
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
|
|
||||||
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
|
|
||||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
|
|
||||||
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
|
|
||||||
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
|
|
||||||
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
|
|
||||||
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
|
|
||||||
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
|
|
||||||
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
|
|
||||||
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
|
|
||||||
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
|
|
||||||
{NULL, NULL},
|
|
||||||
};
|
|
||||||
|
|
||||||
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
|
|
||||||
if (!resp->ch.handle) {
|
|
||||||
char *msg = LOAD_ERR();
|
|
||||||
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
|
|
||||||
snprintf(buf, buflen,
|
|
||||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
|
||||||
nvml_lib_path, msg);
|
|
||||||
free(msg);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO once we've squashed the remaining corner cases remove this log
|
|
||||||
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
|
|
||||||
|
|
||||||
for (i = 0; l[i].s != NULL; i++) {
|
|
||||||
// TODO once we've squashed the remaining corner cases remove this log
|
|
||||||
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
|
||||||
|
|
||||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
|
||||||
if (!l[i].p) {
|
|
||||||
resp->ch.handle = NULL;
|
|
||||||
char *msg = LOAD_ERR();
|
|
||||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
|
||||||
UNLOAD_LIBRARY(resp->ch.handle);
|
|
||||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
|
||||||
msg);
|
|
||||||
free(msg);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = (*resp->ch.nvmlInit_v2)();
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
|
|
||||||
UNLOAD_LIBRARY(resp->ch.handle);
|
|
||||||
resp->ch.handle = NULL;
|
|
||||||
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Report driver version if we're in verbose mode, ignore errors
|
|
||||||
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
|
|
||||||
} else {
|
|
||||||
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
|
|
||||||
resp->err = NULL;
|
|
||||||
nvmlDevice_t device;
|
|
||||||
nvmlMemory_t memInfo = {0};
|
|
||||||
nvmlReturn_t ret;
|
|
||||||
const int buflen = 256;
|
|
||||||
char buf[buflen + 1];
|
|
||||||
int i;
|
|
||||||
|
|
||||||
if (h.handle == NULL) {
|
|
||||||
resp->err = strdup("nvml handle isn't initialized");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
resp->total = 0;
|
|
||||||
resp->free = 0;
|
|
||||||
for (i = 0; i < resp->count; i++) {
|
|
||||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (h.verbose) {
|
|
||||||
nvmlBrandType_t brand = 0;
|
|
||||||
// When in verbose mode, report more information about
|
|
||||||
// the card we discover, but don't fail on error
|
|
||||||
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
|
|
||||||
} else {
|
|
||||||
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
|
|
||||||
}
|
|
||||||
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
|
|
||||||
} else {
|
|
||||||
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
|
|
||||||
}
|
|
||||||
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
|
|
||||||
} else {
|
|
||||||
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
|
|
||||||
}
|
|
||||||
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
|
|
||||||
} else {
|
|
||||||
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
|
|
||||||
}
|
|
||||||
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
|
|
||||||
} else {
|
|
||||||
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
|
|
||||||
LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
|
|
||||||
|
|
||||||
resp->total += memInfo.total;
|
|
||||||
resp->free += memInfo.free;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
|
|
||||||
resp->err = NULL;
|
|
||||||
resp->major = 0;
|
|
||||||
resp->minor = 0;
|
|
||||||
nvmlDevice_t device;
|
|
||||||
int major = 0;
|
|
||||||
int minor = 0;
|
|
||||||
nvmlReturn_t ret;
|
|
||||||
const int buflen = 256;
|
|
||||||
char buf[buflen + 1];
|
|
||||||
int i;
|
|
||||||
|
|
||||||
if (h.handle == NULL) {
|
|
||||||
resp->err = strdup("nvml handle not initialized");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned int devices;
|
|
||||||
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (i = 0; i < devices; i++) {
|
|
||||||
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
|
|
||||||
if (ret != NVML_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Report the lowest major.minor we detect as that limits our compatibility
|
|
||||||
if (resp->major == 0 || resp->major > major ) {
|
|
||||||
resp->major = major;
|
|
||||||
resp->minor = minor;
|
|
||||||
} else if ( resp->major == major && resp->minor > minor ) {
|
|
||||||
resp->minor = minor;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void nvml_release(nvml_handle_t h) {
|
|
||||||
LOG(h.verbose, "releasing nvml library\n");
|
|
||||||
UNLOAD_LIBRARY(h.handle);
|
|
||||||
h.handle = NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // __APPLE__
|
|
|
@ -1,57 +0,0 @@
|
||||||
#ifndef __APPLE__
|
|
||||||
#ifndef __GPU_INFO_NVML_H__
|
|
||||||
#define __GPU_INFO_NVML_H__
|
|
||||||
#include "gpu_info.h"
|
|
||||||
|
|
||||||
// Just enough typedef's to dlopen/dlsym for memory information
|
|
||||||
typedef enum nvmlReturn_enum {
|
|
||||||
NVML_SUCCESS = 0,
|
|
||||||
// Other values omitted for now...
|
|
||||||
} nvmlReturn_t;
|
|
||||||
typedef void *nvmlDevice_t; // Opaque is sufficient
|
|
||||||
typedef struct nvmlMemory_st {
|
|
||||||
unsigned long long total;
|
|
||||||
unsigned long long free;
|
|
||||||
unsigned long long used;
|
|
||||||
} nvmlMemory_t;
|
|
||||||
|
|
||||||
typedef enum nvmlBrandType_enum
|
|
||||||
{
|
|
||||||
NVML_BRAND_UNKNOWN = 0,
|
|
||||||
} nvmlBrandType_t;
|
|
||||||
|
|
||||||
typedef struct nvml_handle {
|
|
||||||
void *handle;
|
|
||||||
uint16_t verbose;
|
|
||||||
nvmlReturn_t (*nvmlInit_v2)(void);
|
|
||||||
nvmlReturn_t (*nvmlShutdown)(void);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
|
|
||||||
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
|
|
||||||
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
|
|
||||||
} nvml_handle_t;
|
|
||||||
|
|
||||||
typedef struct nvml_init_resp {
|
|
||||||
char *err; // If err is non-null handle is invalid
|
|
||||||
nvml_handle_t ch;
|
|
||||||
} nvml_init_resp_t;
|
|
||||||
|
|
||||||
typedef struct nvml_compute_capability {
|
|
||||||
char *err;
|
|
||||||
int major;
|
|
||||||
int minor;
|
|
||||||
} nvml_compute_capability_t;
|
|
||||||
|
|
||||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
|
|
||||||
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
|
|
||||||
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
|
|
||||||
void nvml_release(nvml_handle_t ch);
|
|
||||||
|
|
||||||
#endif // __GPU_INFO_NVML_H__
|
|
||||||
#endif // __APPLE__
|
|
|
@ -9,23 +9,16 @@ import (
|
||||||
|
|
||||||
func TestBasicGetGPUInfo(t *testing.T) {
|
func TestBasicGetGPUInfo(t *testing.T) {
|
||||||
info := GetGPUInfo()
|
info := GetGPUInfo()
|
||||||
assert.Contains(t, "cuda rocm cpu metal", info.Library)
|
assert.Greater(t, len(info), 0)
|
||||||
|
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
|
||||||
switch runtime.GOOS {
|
if info[0].Library != "cpu" {
|
||||||
case "darwin":
|
assert.Greater(t, info[0].TotalMemory, uint64(0))
|
||||||
// TODO - remove this once MacOS returns some size for CPU
|
assert.Greater(t, info[0].FreeMemory, uint64(0))
|
||||||
return
|
|
||||||
case "linux", "windows":
|
|
||||||
assert.Greater(t, info.TotalMemory, uint64(0))
|
|
||||||
assert.Greater(t, info.FreeMemory, uint64(0))
|
|
||||||
assert.Greater(t, info.DeviceCount, uint32(0))
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCPUMemInfo(t *testing.T) {
|
func TestCPUMemInfo(t *testing.T) {
|
||||||
info, err := getCPUMem()
|
info, err := GetCPUMem()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "darwin":
|
case "darwin":
|
||||||
|
|
49
gpu/types.go
49
gpu/types.go
|
@ -3,7 +3,6 @@ package gpu
|
||||||
type memInfo struct {
|
type memInfo struct {
|
||||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||||
DeviceCount uint32 `json:"device_count,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Beginning of an `ollama info` command
|
// Beginning of an `ollama info` command
|
||||||
|
@ -17,11 +16,49 @@ type GpuInfo struct {
|
||||||
// MinimumMemory represents the minimum memory required to use the GPU
|
// MinimumMemory represents the minimum memory required to use the GPU
|
||||||
MinimumMemory uint64 `json:"-"`
|
MinimumMemory uint64 `json:"-"`
|
||||||
|
|
||||||
// TODO add other useful attributes about the card here for discovery information
|
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||||
|
DependencyPath string `json:"lib_path,omitempty"`
|
||||||
|
|
||||||
|
// GPU information
|
||||||
|
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||||
|
Name string `json:"name"` // user friendly name if available
|
||||||
|
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
|
||||||
|
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
|
||||||
|
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
|
||||||
|
|
||||||
|
// TODO other performance capability info to help in scheduling decisions
|
||||||
}
|
}
|
||||||
|
|
||||||
type Version struct {
|
type GpuInfoList []GpuInfo
|
||||||
Major uint
|
|
||||||
Minor uint
|
// Split up the set of gpu info's by Library and variant
|
||||||
Patch uint
|
func (l GpuInfoList) ByLibrary() []GpuInfoList {
|
||||||
|
resp := []GpuInfoList{}
|
||||||
|
libs := []string{}
|
||||||
|
for _, info := range l {
|
||||||
|
found := false
|
||||||
|
requested := info.Library
|
||||||
|
if info.Variant != "" {
|
||||||
|
requested += "_" + info.Variant
|
||||||
|
}
|
||||||
|
for i, lib := range libs {
|
||||||
|
if lib == requested {
|
||||||
|
resp[i] = append(resp[i], info)
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
libs = append(libs, info.Library)
|
||||||
|
resp = append(resp, []GpuInfo{info})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sort by Free Space
|
||||||
|
type ByFreeMemory []GpuInfo
|
||||||
|
|
||||||
|
func (a ByFreeMemory) Len() int { return len(a) }
|
||||||
|
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
|
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
|
||||||
|
|
|
@ -4,11 +4,14 @@ package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOrcaMiniBlueSky(t *testing.T) {
|
func TestOrcaMiniBlueSky(t *testing.T) {
|
||||||
|
@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
|
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnicodeModelDir(t *testing.T) {
|
||||||
|
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
t.Skip("Unicode test only applicable to windows")
|
||||||
|
}
|
||||||
|
// Only works for local testing
|
||||||
|
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||||
|
t.Skip("TestUnicodeModelDir only works for local testing, skipping")
|
||||||
|
}
|
||||||
|
|
||||||
|
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(modelDir)
|
||||||
|
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||||
|
|
||||||
|
oldModelsDir := os.Getenv("OLLAMA_MODELS")
|
||||||
|
if oldModelsDir == "" {
|
||||||
|
defer os.Unsetenv("OLLAMA_MODELS")
|
||||||
|
} else {
|
||||||
|
defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
|
||||||
|
}
|
||||||
|
err = os.Setenv("OLLAMA_MODELS", modelDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.GenerateRequest{
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the sky blue?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"temperature": 0,
|
||||||
|
"seed": 123,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||||
}
|
}
|
||||||
|
|
225
integration/concurrency_test.go
Normal file
225
integration/concurrency_test.go
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMultiModelConcurrency(t *testing.T) {
|
||||||
|
var (
|
||||||
|
req = [2]api.GenerateRequest{
|
||||||
|
{
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the ocean blue?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "tinydolphin",
|
||||||
|
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp = [2][]string{
|
||||||
|
[]string{"sunlight"},
|
||||||
|
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(req))
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||||
|
defer cancel()
|
||||||
|
for i := 0; i < len(req); i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
GenerateTestHelper(ctx, t, req[i], resp[i])
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
req, resp := GenerateRequests()
|
||||||
|
// Get the server running (if applicable) warm the model up with a single initial request
|
||||||
|
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(req))
|
||||||
|
for i := 0; i < len(req); i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 5; j++ {
|
||||||
|
slog.Info("Starting", "req", i, "iter", j)
|
||||||
|
// On slower GPUs it can take a while to process the 4 concurrent requests
|
||||||
|
// so we allow a much longer initial timeout
|
||||||
|
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||||
|
func TestMultiModelStress(t *testing.T) {
|
||||||
|
vram := os.Getenv("OLLAMA_MAX_VRAM")
|
||||||
|
if vram == "" {
|
||||||
|
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||||
|
}
|
||||||
|
max, err := strconv.ParseUint(vram, 10, 64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
const MB = uint64(1024 * 1024)
|
||||||
|
type model struct {
|
||||||
|
name string
|
||||||
|
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
|
||||||
|
}
|
||||||
|
|
||||||
|
smallModels := []model{
|
||||||
|
{
|
||||||
|
name: "orca-mini",
|
||||||
|
size: 2992 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "phi",
|
||||||
|
size: 2616 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemma:2b",
|
||||||
|
size: 2364 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stable-code:3b",
|
||||||
|
size: 2608 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "starcoder2:3b",
|
||||||
|
size: 2166 * MB,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mediumModels := []model{
|
||||||
|
{
|
||||||
|
name: "llama2",
|
||||||
|
size: 5118 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral",
|
||||||
|
size: 4620 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "orca-mini:7b",
|
||||||
|
size: 5118 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dolphin-mistral",
|
||||||
|
size: 4620 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemma:7b",
|
||||||
|
size: 5000 * MB,
|
||||||
|
},
|
||||||
|
// TODO - uncomment this once #3565 is merged and this is rebased on it
|
||||||
|
// {
|
||||||
|
// name: "codellama:7b",
|
||||||
|
// size: 5118 * MB,
|
||||||
|
// },
|
||||||
|
}
|
||||||
|
|
||||||
|
// These seem to be too slow to be useful...
|
||||||
|
// largeModels := []model{
|
||||||
|
// {
|
||||||
|
// name: "llama2:13b",
|
||||||
|
// size: 7400 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "codellama:13b",
|
||||||
|
// size: 7400 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "orca-mini:13b",
|
||||||
|
// size: 7400 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "gemma:7b",
|
||||||
|
// size: 5000 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "starcoder2:15b",
|
||||||
|
// size: 9100 * MB,
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
var chosenModels []model
|
||||||
|
switch {
|
||||||
|
case max < 10000*MB:
|
||||||
|
slog.Info("selecting small models")
|
||||||
|
chosenModels = smallModels
|
||||||
|
// case max < 30000*MB:
|
||||||
|
default:
|
||||||
|
slog.Info("selecting medium models")
|
||||||
|
chosenModels = mediumModels
|
||||||
|
// default:
|
||||||
|
// slog.Info("selecting large models")
|
||||||
|
// chosenModels = largModels
|
||||||
|
}
|
||||||
|
|
||||||
|
req, resp := GenerateRequests()
|
||||||
|
|
||||||
|
for i := range req {
|
||||||
|
if i > len(chosenModels) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
req[i].Model = chosenModels[i].name
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Make sure all the models are pulled before we get started
|
||||||
|
for _, r := range req {
|
||||||
|
require.NoError(t, PullIfMissing(ctx, client, r.Model))
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
consumed := uint64(256 * MB) // Assume some baseline usage
|
||||||
|
for i := 0; i < len(req); i++ {
|
||||||
|
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
|
||||||
|
if i > 1 && consumed > max {
|
||||||
|
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
consumed += chosenModels[i].size
|
||||||
|
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 3; j++ {
|
||||||
|
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
|
||||||
|
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
|
@ -4,7 +4,6 @@ package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
|
||||||
"num_ctx": 128,
|
"num_ctx": 128,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
|
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ package integration
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := "the ollamas"
|
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||||
|
resp := "the ollam"
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
|
GenerateTestHelper(ctx, t, req, []string{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||||
|
|
|
@ -4,8 +4,6 @@ package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -45,25 +43,5 @@ var (
|
||||||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
|
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
|
||||||
// The server always loads a new runner and closes the old one, which forces serial execution
|
|
||||||
// At present this test case fails with concurrency problems. Eventually we should try to
|
|
||||||
// get true concurrency working with n_parallel support in the backend
|
|
||||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(len(req))
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
|
||||||
defer cancel()
|
|
||||||
for i := 0; i < len(req); i++ {
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO - create a parallel test with 2 different models once we support concurrency
|
|
||||||
|
|
|
@ -5,13 +5,14 @@ package integration
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -23,9 +24,13 @@ import (
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/lifecycle"
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func Init() {
|
||||||
|
lifecycle.InitLogging()
|
||||||
|
}
|
||||||
|
|
||||||
func FindPort() string {
|
func FindPort() string {
|
||||||
port := 0
|
port := 0
|
||||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||||
|
@ -41,7 +46,7 @@ func FindPort() string {
|
||||||
return strconv.Itoa(port)
|
return strconv.Itoa(port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTestEndpoint() (string, string) {
|
func GetTestEndpoint() (*api.Client, string) {
|
||||||
defaultPort := "11434"
|
defaultPort := "11434"
|
||||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||||
|
|
||||||
|
@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
|
||||||
port = FindPort()
|
port = FindPort()
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s:%s", host, port)
|
slog.Info("server connection", "host", host, "port", port)
|
||||||
slog.Info("server connection", "url", url)
|
|
||||||
return scheme, url
|
return api.NewClient(
|
||||||
|
&url.URL{
|
||||||
|
Scheme: scheme,
|
||||||
|
Host: net.JoinHostPort(host, port),
|
||||||
|
},
|
||||||
|
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO make fanicier, grab logs, etc.
|
|
||||||
var serverMutex sync.Mutex
|
var serverMutex sync.Mutex
|
||||||
var serverReady bool
|
var serverReady bool
|
||||||
|
|
||||||
func StartServer(ctx context.Context, ollamaHost string) error {
|
func startServer(ctx context.Context, ollamaHost string) error {
|
||||||
// Make sure the server has been built
|
// Make sure the server has been built
|
||||||
CLIName, err := filepath.Abs("../ollama")
|
CLIName, err := filepath.Abs("../ollama")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
|
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
|
||||||
slog.Info("checking status of model", "model", modelName)
|
slog.Info("checking status of model", "model", modelName)
|
||||||
showReq := &api.ShowRequest{Name: modelName}
|
showReq := &api.ShowRequest{Name: modelName}
|
||||||
requestJSON, err := json.Marshal(showReq)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
|
showCtx, cancel := context.WithDeadlineCause(
|
||||||
if err != nil {
|
ctx,
|
||||||
|
time.Now().Add(5*time.Second),
|
||||||
|
fmt.Errorf("show for existing model %s took too long", modelName),
|
||||||
|
)
|
||||||
|
defer cancel()
|
||||||
|
_, err := client.Show(showCtx, showReq)
|
||||||
|
var statusError api.StatusError
|
||||||
|
switch {
|
||||||
|
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||||
|
break
|
||||||
|
case err != nil:
|
||||||
return err
|
return err
|
||||||
}
|
default:
|
||||||
|
|
||||||
// Make the request with the HTTP client
|
|
||||||
response, err := client.Do(req.WithContext(ctx))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer response.Body.Close()
|
|
||||||
if response.StatusCode == 200 {
|
|
||||||
slog.Info("model already present", "model", modelName)
|
slog.Info("model already present", "model", modelName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
slog.Info("model missing", "status", response.StatusCode)
|
slog.Info("model missing", "model", modelName)
|
||||||
|
|
||||||
|
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
|
||||||
|
stallTimer := time.NewTimer(stallDuration)
|
||||||
|
fn := func(resp api.ProgressResponse) error {
|
||||||
|
// fmt.Print(".")
|
||||||
|
if !stallTimer.Reset(stallDuration) {
|
||||||
|
return fmt.Errorf("stall was detected, aborting status reporting")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
|
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
|
||||||
requestJSON, err = json.Marshal(pullReq)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
|
var pullError error
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
slog.Info("pulling", "model", modelName)
|
|
||||||
|
|
||||||
response, err = client.Do(req.WithContext(ctx))
|
done := make(chan int)
|
||||||
if err != nil {
|
go func() {
|
||||||
return err
|
pullError = client.Pull(ctx, pullReq, fn)
|
||||||
|
done <- 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
|
return fmt.Errorf("download stalled")
|
||||||
|
case <-done:
|
||||||
|
return pullError
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
|
||||||
if response.StatusCode != 200 {
|
|
||||||
return fmt.Errorf("failed to pull model") // TODO more details perhaps
|
|
||||||
}
|
|
||||||
slog.Info("model pulled", "model", modelName)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var serverProcMutex sync.Mutex
|
var serverProcMutex sync.Mutex
|
||||||
|
|
||||||
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
|
// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
|
||||||
|
// Starts the server if needed
|
||||||
// TODO maybe stuff in an init routine?
|
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
|
||||||
lifecycle.InitLogging()
|
client, testEndpoint := GetTestEndpoint()
|
||||||
|
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||||
requestJSON, err := json.Marshal(genReq)
|
serverProcMutex.Lock()
|
||||||
if err != nil {
|
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||||
t.Fatalf("Error serializing request: %v", err)
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate log file: %s", err)
|
||||||
|
}
|
||||||
|
lifecycle.ServerLogFile = fp.Name()
|
||||||
|
fp.Close()
|
||||||
|
require.NoError(t, startServer(ctx, testEndpoint))
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
|
return client, testEndpoint, func() {
|
||||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||||
defer serverProcMutex.Unlock()
|
defer serverProcMutex.Unlock()
|
||||||
if t.Failed() {
|
if t.Failed() {
|
||||||
|
@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
|
||||||
os.Stderr.Write(data)
|
os.Stderr.Write(data)
|
||||||
slog.Warn("END OF SERVER")
|
slog.Warn("END OF SERVER")
|
||||||
}
|
}
|
||||||
err = os.Remove(lifecycle.ServerLogFile)
|
err := os.Remove(lifecycle.ServerLogFile)
|
||||||
if err != nil && !os.IsNotExist(err) {
|
if err != nil && !os.IsNotExist(err) {
|
||||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
scheme, testEndpoint := GetTestEndpoint()
|
|
||||||
|
|
||||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
|
||||||
serverProcMutex.Lock()
|
|
||||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to generate log file: %s", err)
|
|
||||||
}
|
|
||||||
lifecycle.ServerLogFile = fp.Name()
|
|
||||||
fp.Close()
|
|
||||||
assert.NoError(t, StartServer(ctx, testEndpoint))
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
|
|
||||||
if err != nil {
|
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||||
t.Fatalf("Error pulling model: %v", err)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
}
|
defer cleanup()
|
||||||
|
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||||
// Make the request and get the response
|
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
|
}
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating request: %v", err)
|
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
|
||||||
}
|
stallTimer := time.NewTimer(initialTimeout)
|
||||||
|
var buf bytes.Buffer
|
||||||
// Set the content type for the request
|
fn := func(response api.GenerateResponse) error {
|
||||||
req.Header.Set("Content-Type", "application/json")
|
// fmt.Print(".")
|
||||||
|
buf.Write([]byte(response.Response))
|
||||||
// Make the request with the HTTP client
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
response, err := client.Do(req.WithContext(ctx))
|
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||||
if err != nil {
|
}
|
||||||
t.Fatalf("Error making request: %v", err)
|
return nil
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
|
||||||
body, err := io.ReadAll(response.Body)
|
stream := true
|
||||||
assert.NoError(t, err)
|
genReq.Stream = &stream
|
||||||
assert.Equal(t, response.StatusCode, 200, string(body))
|
done := make(chan int)
|
||||||
|
var genErr error
|
||||||
// Verify the response is valid JSON
|
go func() {
|
||||||
var payload api.GenerateResponse
|
genErr = client.Generate(ctx, &genReq, fn)
|
||||||
err = json.Unmarshal(body, &payload)
|
done <- 0
|
||||||
if err != nil {
|
}()
|
||||||
assert.NoError(t, err, body)
|
|
||||||
}
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
// Verify the response contains the expected data
|
if buf.Len() == 0 {
|
||||||
atLeastOne := false
|
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||||
for _, resp := range anyResp {
|
} else {
|
||||||
if strings.Contains(strings.ToLower(payload.Response), resp) {
|
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||||
atLeastOne = true
|
}
|
||||||
break
|
case <-done:
|
||||||
}
|
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||||
}
|
// Verify the response contains the expected data
|
||||||
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
|
response := buf.String()
|
||||||
|
atLeastOne := false
|
||||||
|
for _, resp := range anyResp {
|
||||||
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
|
atLeastOne = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
|
||||||
|
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Error("outer test context done while waiting for generate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a set of requests
|
||||||
|
// By default each request uses orca-mini as the model
|
||||||
|
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||||
|
return []api.GenerateRequest{
|
||||||
|
{
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the ocean blue?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the color of dirt brown?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "what is the origin of independence day?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "what is the composition of air?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
[][]string{
|
||||||
|
[]string{"sunlight"},
|
||||||
|
[]string{"soil", "organic", "earth", "black", "tan"},
|
||||||
|
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||||
|
[]string{"fourth", "july", "declaration", "independence"},
|
||||||
|
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
162
llm/memory.go
Normal file
162
llm/memory.go
Normal file
|
@ -0,0 +1,162 @@
|
||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||||
|
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
|
||||||
|
var estimatedVRAM uint64
|
||||||
|
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||||
|
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||||
|
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.NumCtx < 4 {
|
||||||
|
opts.NumCtx = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split up the GPUs by type and try them
|
||||||
|
for _, gpus := range allGpus.ByLibrary() {
|
||||||
|
var layerCount int
|
||||||
|
layerCount, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||||
|
if opts.NumGPU < 0 {
|
||||||
|
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
|
||||||
|
return true, estimatedVRAM
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if layerCount > 0 && layerCount >= opts.NumGPU {
|
||||||
|
return true, estimatedVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, estimatedVRAM
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a model and one or more GPU targets, predict how many layers and bytes we can load
|
||||||
|
// The GPUs provided must all be the same Library
|
||||||
|
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64) {
|
||||||
|
if gpus[0].Library == "cpu" {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
var memoryAvailable uint64
|
||||||
|
for _, info := range gpus {
|
||||||
|
memoryAvailable += info.FreeMemory
|
||||||
|
}
|
||||||
|
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
|
||||||
|
|
||||||
|
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
|
||||||
|
memoryMinimum := gpus[0].MinimumMemory
|
||||||
|
|
||||||
|
for _, projector := range projectors {
|
||||||
|
memoryMinimum += projectorMemoryRequirements(projector)
|
||||||
|
|
||||||
|
// multimodal models require at least 2048 context
|
||||||
|
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
||||||
|
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
|
||||||
|
|
||||||
|
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
||||||
|
if graphPartialOffload == 0 {
|
||||||
|
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
||||||
|
}
|
||||||
|
|
||||||
|
if graphFullOffload == 0 {
|
||||||
|
graphFullOffload = graphPartialOffload
|
||||||
|
}
|
||||||
|
|
||||||
|
graphFullOffload *= uint64(len(gpus))
|
||||||
|
graphPartialOffload *= uint64(len(gpus))
|
||||||
|
|
||||||
|
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
||||||
|
memoryRequiredTotal := memoryMinimum + graphFullOffload
|
||||||
|
|
||||||
|
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
|
||||||
|
memoryRequiredPartial := memoryMinimum + graphPartialOffload
|
||||||
|
|
||||||
|
if memoryRequiredPartial > memoryAvailable {
|
||||||
|
slog.Debug("insufficient VRAM to load any model layers")
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var layerCount int
|
||||||
|
layers := ggml.Tensors().Layers()
|
||||||
|
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
||||||
|
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
|
||||||
|
|
||||||
|
// KV is proportional to the number of layers
|
||||||
|
memoryLayer += kv / ggml.KV().BlockCount()
|
||||||
|
|
||||||
|
memoryRequiredTotal += memoryLayer
|
||||||
|
if memoryAvailable > memoryRequiredPartial+memoryLayer {
|
||||||
|
memoryRequiredPartial += memoryLayer
|
||||||
|
layerCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var memoryLayerOutput uint64
|
||||||
|
for k, v := range layers {
|
||||||
|
if !strings.HasPrefix(k, "blk.") {
|
||||||
|
memoryLayerOutput += v.size()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryRequiredTotal += memoryLayerOutput
|
||||||
|
|
||||||
|
if memoryAvailable > memoryRequiredTotal {
|
||||||
|
layerCount = int(ggml.KV().BlockCount()) + 1
|
||||||
|
memoryRequiredPartial = memoryRequiredTotal
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
|
||||||
|
|
||||||
|
slog.Info(
|
||||||
|
"offload to gpu",
|
||||||
|
slog.Group(
|
||||||
|
"layers",
|
||||||
|
// actual number of layers offloaded
|
||||||
|
"real", opts.NumGPU,
|
||||||
|
// estimated number of layers that can be offloaded
|
||||||
|
"estimate", layerCount,
|
||||||
|
),
|
||||||
|
slog.Group(
|
||||||
|
"memory",
|
||||||
|
// memory available for offloading
|
||||||
|
"available", format.HumanBytes2(memoryAvailable),
|
||||||
|
slog.Group(
|
||||||
|
"required",
|
||||||
|
// memory required for full offloading
|
||||||
|
"full", format.HumanBytes2(memoryRequiredTotal),
|
||||||
|
// memory required to offload layers.estimate layers
|
||||||
|
"partial", format.HumanBytes2(memoryRequiredPartial),
|
||||||
|
// memory of KV cache
|
||||||
|
"kv", format.HumanBytes2(kv),
|
||||||
|
),
|
||||||
|
slog.Group(
|
||||||
|
"weights",
|
||||||
|
// memory of the weights
|
||||||
|
"total", format.HumanBytes2(memoryWeights),
|
||||||
|
// memory of repeating layers
|
||||||
|
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
|
||||||
|
// memory of non-repeating layers
|
||||||
|
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
|
||||||
|
),
|
||||||
|
slog.Group(
|
||||||
|
"graph",
|
||||||
|
// memory of graph when fully offloaded
|
||||||
|
"full", format.HumanBytes2(graphFullOffload),
|
||||||
|
// memory of graph when not fully offloaded
|
||||||
|
"partial", format.HumanBytes2(graphPartialOffload),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return layerCount, uint64(memoryRequiredPartial)
|
||||||
|
}
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
@ -138,6 +139,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
|
||||||
return servers
|
return servers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the optimal server for this CPU architecture
|
||||||
|
func serverForCpu() string {
|
||||||
|
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||||
|
return "metal"
|
||||||
|
}
|
||||||
|
variant := gpu.GetCPUVariant()
|
||||||
|
availableServers := availableServers()
|
||||||
|
if variant != "" {
|
||||||
|
for cmp := range availableServers {
|
||||||
|
if cmp == "cpu_"+variant {
|
||||||
|
return cmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "cpu"
|
||||||
|
}
|
||||||
|
|
||||||
// extract extracts the embedded files to the target directory
|
// extract extracts the embedded files to the target directory
|
||||||
func extractFiles(targetDir string, glob string) error {
|
func extractFiles(targetDir string, glob string) error {
|
||||||
files, err := fs.Glob(libEmbed, glob)
|
files, err := fs.Glob(libEmbed, glob)
|
||||||
|
|
299
llm/server.go
299
llm/server.go
|
@ -21,21 +21,43 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LlamaServer is an instance of the llama.cpp server
|
type LlamaServer interface {
|
||||||
type LlamaServer struct {
|
Ping(ctx context.Context) error
|
||||||
|
WaitUntilRunning(ctx context.Context) error
|
||||||
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
|
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
||||||
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
|
Close() error
|
||||||
|
EstimatedVRAM() uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// llmServer is an instance of the llama.cpp server
|
||||||
|
type llmServer struct {
|
||||||
port int
|
port int
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
done chan error // Channel to signal when the process exits
|
done chan error // Channel to signal when the process exits
|
||||||
status *StatusWriter
|
status *StatusWriter
|
||||||
options api.Options
|
options api.Options
|
||||||
|
|
||||||
|
// TODO - this should be broken down by GPU
|
||||||
|
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
|
||||||
|
|
||||||
|
sem *semaphore.Weighted
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
|
func LoadModel(model string) (*GGML, error) {
|
||||||
|
if _, err := os.Stat(model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
f, err := os.Open(model)
|
f, err := os.Open(model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -43,10 +65,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
ggml, _, err := DecodeGGML(f)
|
ggml, _, err := DecodeGGML(f)
|
||||||
if err != nil {
|
return ggml, err
|
||||||
return nil, err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// NewLlamaServer will run a server for the given GPUs
|
||||||
|
// The gpu list must be a single family.
|
||||||
|
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
|
||||||
|
var err error
|
||||||
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||||
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||||
opts.NumCtx = int(ggml.KV().ContextLength())
|
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||||
|
@ -56,130 +81,50 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
|
|
||||||
memoryAvailable, _ := gpu.CheckVRAM()
|
cpuRunner := ""
|
||||||
info := gpu.GetGPUInfo()
|
var estimatedVRAM uint64
|
||||||
|
var systemMemory uint64
|
||||||
|
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
|
||||||
|
|
||||||
memoryMinimum := info.MinimumMemory
|
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
|
||||||
for _, projector := range projectors {
|
|
||||||
memoryMinimum += projectorMemoryRequirements(projector)
|
|
||||||
|
|
||||||
// multimodal models require at least 2048 context
|
cpuRunner = serverForCpu()
|
||||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
} else {
|
||||||
}
|
if gpus[0].Library == "metal" {
|
||||||
|
memInfo, err := gpu.GetCPUMem()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to lookup system memory", "error", err)
|
||||||
|
} else {
|
||||||
|
systemMemory = memInfo.TotalMemory
|
||||||
|
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var layers int
|
||||||
|
layers, estimatedVRAM = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||||
|
|
||||||
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
|
||||||
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
|
// disable partial offloading when model is greater than total system memory as this
|
||||||
|
// can lead to locking up the system
|
||||||
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
opts.NumGPU = 0
|
||||||
if graphPartialOffload == 0 {
|
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
|
||||||
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
opts.NumGPU = layers
|
||||||
}
|
|
||||||
|
|
||||||
if graphFullOffload == 0 {
|
|
||||||
graphFullOffload = graphPartialOffload
|
|
||||||
}
|
|
||||||
|
|
||||||
graphFullOffload *= uint64(info.DeviceCount)
|
|
||||||
graphPartialOffload *= uint64(info.DeviceCount)
|
|
||||||
|
|
||||||
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
|
||||||
memoryRequiredTotal := memoryMinimum + graphFullOffload
|
|
||||||
|
|
||||||
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
|
|
||||||
memoryRequiredPartial := memoryMinimum + graphPartialOffload
|
|
||||||
|
|
||||||
if info.Library != "metal" {
|
|
||||||
if memoryRequiredPartial > memoryAvailable {
|
|
||||||
info.Library = "cpu"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var layerCount int
|
// Loop through potential servers
|
||||||
layers := ggml.Tensors().Layers()
|
finalErr := fmt.Errorf("no suitable llama servers found")
|
||||||
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
|
||||||
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
|
|
||||||
|
|
||||||
// KV is proportional to the number of layers
|
|
||||||
memoryLayer += kv / ggml.KV().BlockCount()
|
|
||||||
|
|
||||||
memoryRequiredTotal += memoryLayer
|
|
||||||
if memoryAvailable > memoryRequiredPartial+memoryLayer {
|
|
||||||
memoryRequiredPartial += memoryLayer
|
|
||||||
layerCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var memoryLayerOutput uint64
|
|
||||||
for k, v := range layers {
|
|
||||||
if !strings.HasPrefix(k, "blk.") {
|
|
||||||
memoryLayerOutput += v.size()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
memoryRequiredTotal += memoryLayerOutput
|
|
||||||
|
|
||||||
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
|
|
||||||
// disable partial offloading when model is greater than total system memory
|
|
||||||
opts.NumGPU = 0
|
|
||||||
} else if memoryAvailable > memoryRequiredTotal {
|
|
||||||
layerCount = int(ggml.KV().BlockCount()) + 1
|
|
||||||
memoryRequiredPartial = memoryRequiredTotal
|
|
||||||
}
|
|
||||||
|
|
||||||
if opts.NumGPU < 0 {
|
|
||||||
opts.NumGPU = layerCount
|
|
||||||
}
|
|
||||||
|
|
||||||
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
|
|
||||||
|
|
||||||
slog.Info(
|
|
||||||
"offload to gpu",
|
|
||||||
slog.Group(
|
|
||||||
"layers",
|
|
||||||
// actual number of layers offloaded
|
|
||||||
"real", opts.NumGPU,
|
|
||||||
// estimated number of layers that can be offloaded
|
|
||||||
"estimate", layerCount,
|
|
||||||
),
|
|
||||||
slog.Group(
|
|
||||||
"memory",
|
|
||||||
// memory available for offloading
|
|
||||||
"available", format.HumanBytes2(memoryAvailable),
|
|
||||||
slog.Group(
|
|
||||||
"required",
|
|
||||||
// memory required for full offloading
|
|
||||||
"full", format.HumanBytes2(memoryRequiredTotal),
|
|
||||||
// memory required to offload layers.estimate layers
|
|
||||||
"partial", format.HumanBytes2(memoryRequiredPartial),
|
|
||||||
// memory of KV cache
|
|
||||||
"kv", format.HumanBytes2(kv),
|
|
||||||
),
|
|
||||||
slog.Group(
|
|
||||||
"weights",
|
|
||||||
// memory of the weights
|
|
||||||
"total", format.HumanBytes2(memoryWeights),
|
|
||||||
// memory of repeating layers
|
|
||||||
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
|
|
||||||
// memory of non-repeating layers
|
|
||||||
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
|
|
||||||
),
|
|
||||||
slog.Group(
|
|
||||||
"graph",
|
|
||||||
// memory of graph when fully offloaded
|
|
||||||
"full", format.HumanBytes2(graphFullOffload),
|
|
||||||
// memory of graph when not fully offloaded
|
|
||||||
"partial", format.HumanBytes2(graphPartialOffload),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(adapters) > 1 {
|
if len(adapters) > 1 {
|
||||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
availableServers := availableServers()
|
availableServers := availableServers()
|
||||||
servers := serversForGpu(info)
|
var servers []string
|
||||||
|
if cpuRunner != "" {
|
||||||
|
servers = []string{cpuRunner}
|
||||||
|
} else {
|
||||||
|
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
|
||||||
|
}
|
||||||
demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
|
demandLib := strings.Trim(os.Getenv("OLLAMA_LLM_LIBRARY"), "\"' ")
|
||||||
if demandLib != "" {
|
if demandLib != "" {
|
||||||
serverPath := availableServers[demandLib]
|
serverPath := availableServers[demandLib]
|
||||||
|
@ -192,7 +137,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(servers) == 0 {
|
if len(servers) == 0 {
|
||||||
return nil, fmt.Errorf("no servers found for %v", info)
|
return nil, fmt.Errorf("no servers found for %v", gpus)
|
||||||
}
|
}
|
||||||
|
|
||||||
params := []string{
|
params := []string{
|
||||||
|
@ -249,8 +194,18 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
params = append(params, "--numa")
|
params = append(params, "--numa")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Loop through potential servers
|
// "--cont-batching", // TODO - doesn't seem to have any noticeable perf change for multiple requests
|
||||||
var finalErr error
|
numParallel := 1
|
||||||
|
if onp := os.Getenv("OLLAMA_NUM_PARALLEL"); onp != "" {
|
||||||
|
numParallel, err = strconv.Atoi(onp)
|
||||||
|
if err != nil || numParallel <= 0 {
|
||||||
|
err = fmt.Errorf("invalid OLLAMA_NUM_PARALLEL=%s must be greater than zero - %w", onp, err)
|
||||||
|
slog.Error("misconfiguration", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
||||||
|
|
||||||
for i := 0; i < len(servers); i++ {
|
for i := 0; i < len(servers); i++ {
|
||||||
dir := availableServers[servers[i]]
|
dir := availableServers[servers[i]]
|
||||||
|
|
||||||
|
@ -275,30 +230,49 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
}
|
}
|
||||||
// append the server directory to LD_LIBRARY_PATH/PATH
|
// append the server directory to LD_LIBRARY_PATH/PATH
|
||||||
libraryPaths := []string{dir}
|
libraryPaths := []string{dir}
|
||||||
|
|
||||||
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
||||||
// Append our runner directory to the path
|
// Append our runner directory to the path
|
||||||
// This will favor system libraries over our bundled library dependencies
|
// This will favor system libraries over our bundled library dependencies
|
||||||
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
|
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Note: we always put the dependency path first
|
||||||
|
// since this was the exact version we verified for AMD GPUs
|
||||||
|
// and we favor what the user had in their path
|
||||||
|
if gpus[0].DependencyPath != "" {
|
||||||
|
// TODO refine for multi-gpu support
|
||||||
|
libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
|
||||||
|
}
|
||||||
|
|
||||||
server := filepath.Join(dir, "ollama_llama_server")
|
server := filepath.Join(dir, "ollama_llama_server")
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
server = server + ".exe"
|
server = server + ".exe"
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &LlamaServer{
|
s := &llmServer{
|
||||||
port: port,
|
port: port,
|
||||||
cmd: exec.Command(server, finalParams...),
|
cmd: exec.Command(server, finalParams...),
|
||||||
status: NewStatusWriter(os.Stderr),
|
status: NewStatusWriter(os.Stderr),
|
||||||
options: opts,
|
options: opts,
|
||||||
|
estimatedVRAM: estimatedVRAM,
|
||||||
|
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||||
}
|
}
|
||||||
|
|
||||||
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
|
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
|
||||||
slog.Debug(libEnv)
|
|
||||||
s.cmd.Env = append(os.Environ(), libEnv)
|
s.cmd.Env = append(os.Environ(), libEnv)
|
||||||
s.cmd.Stdout = os.Stdout
|
s.cmd.Stdout = os.Stdout
|
||||||
s.cmd.Stderr = s.status
|
s.cmd.Stderr = s.status
|
||||||
|
|
||||||
|
// TODO - multiple GPU selection logic...
|
||||||
|
key, val := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
|
||||||
|
if key != "" {
|
||||||
|
s.cmd.Env = append(s.cmd.Env, key+"="+val)
|
||||||
|
}
|
||||||
|
|
||||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||||
|
// Log at debug as the environment is inherited and might contain sensitive information
|
||||||
|
slog.Debug("subprocess", "environment", s.cmd.Env)
|
||||||
|
|
||||||
if err = s.cmd.Start(); err != nil {
|
if err = s.cmd.Start(); err != nil {
|
||||||
msg := ""
|
msg := ""
|
||||||
|
@ -316,6 +290,13 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
_ = s.cmd.Wait()
|
_ = s.cmd.Wait()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// TODO - make sure this is all wired up correctly
|
||||||
|
// if err = s.WaitUntilRunning(); err != nil {
|
||||||
|
// slog.Error("error starting llama server", "server", servers[i], "error", err)
|
||||||
|
// s.Close()
|
||||||
|
// finalErr = err
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -353,6 +334,21 @@ const ( // iota is reset to 0
|
||||||
ServerStatusError
|
ServerStatusError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s ServerStatus) ToString() string {
|
||||||
|
switch s {
|
||||||
|
case ServerStatusReady:
|
||||||
|
return "llm server ready"
|
||||||
|
case ServerStatusNoSlotsAvaialble:
|
||||||
|
return "llm busy - no slots available"
|
||||||
|
case ServerStatusLoadingModel:
|
||||||
|
return "llm server loading model"
|
||||||
|
case ServerStatusNotResponding:
|
||||||
|
return "llm server not responding"
|
||||||
|
default:
|
||||||
|
return "llm server error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type ServerStatusResp struct {
|
type ServerStatusResp struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
SlotsIdle int `json:"slots_idle"`
|
SlotsIdle int `json:"slots_idle"`
|
||||||
|
@ -360,7 +356,7 @@ type ServerStatusResp struct {
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||||
// Fail fast if its exited
|
// Fail fast if its exited
|
||||||
if s.cmd.ProcessState != nil {
|
if s.cmd.ProcessState != nil {
|
||||||
msg := ""
|
msg := ""
|
||||||
|
@ -407,7 +403,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Ping(ctx context.Context) error {
|
func (s *llmServer) Ping(ctx context.Context) error {
|
||||||
_, err := s.getServerStatus(ctx)
|
_, err := s.getServerStatus(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug("server unhealthy", "error", err)
|
slog.Debug("server unhealthy", "error", err)
|
||||||
|
@ -416,7 +412,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) WaitUntilRunning() error {
|
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
// TODO we need to wire up a better way to detect hangs during model load and startup of the server
|
// TODO we need to wire up a better way to detect hangs during model load and startup of the server
|
||||||
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
|
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
|
||||||
|
@ -427,6 +423,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
|
||||||
var lastStatus ServerStatus = -1
|
var lastStatus ServerStatus = -1
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
slog.Info("context expired before server started")
|
||||||
|
return fmt.Errorf("timed out waiting for llama runner to start")
|
||||||
case err := <-s.done:
|
case err := <-s.done:
|
||||||
msg := ""
|
msg := ""
|
||||||
if s.status != nil && s.status.LastErrMsg != "" {
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
|
@ -450,9 +449,9 @@ func (s *LlamaServer) WaitUntilRunning() error {
|
||||||
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
c, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatus(c)
|
||||||
if err != nil && lastStatus != status {
|
if err != nil && lastStatus != status {
|
||||||
slog.Debug("server not yet available", "error", err)
|
slog.Debug("server not yet available", "error", err)
|
||||||
lastStatus = status
|
lastStatus = status
|
||||||
|
@ -538,7 +537,12 @@ type CompletionResponse struct {
|
||||||
EvalDuration time.Duration
|
EvalDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer s.sem.Release(1)
|
||||||
request := map[string]any{
|
request := map[string]any{
|
||||||
"prompt": req.Prompt,
|
"prompt": req.Prompt,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
|
@ -569,7 +573,7 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return fmt.Errorf("unexpected server status: %d", status)
|
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Format == "json" {
|
if req.Format == "json" {
|
||||||
|
@ -716,13 +720,18 @@ type EmbeddingResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||||
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer s.sem.Release(1)
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatus(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||||
|
@ -768,13 +777,13 @@ type TokenizeResponse struct {
|
||||||
Tokens []int `json:"tokens"`
|
Tokens []int `json:"tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatus(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
|
||||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: content})
|
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||||
|
@ -820,13 +829,13 @@ type DetokenizeResponse struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatus(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvaialble {
|
||||||
return "", fmt.Errorf("unexpected server status: %d", status)
|
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||||
|
@ -864,7 +873,7 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
|
||||||
return decoded.Content, nil
|
return decoded.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Close() error {
|
func (s *llmServer) Close() error {
|
||||||
if s.cmd != nil {
|
if s.cmd != nil {
|
||||||
slog.Debug("stopping llama server")
|
slog.Debug("stopping llama server")
|
||||||
return s.cmd.Process.Kill()
|
return s.cmd.Process.Kill()
|
||||||
|
@ -873,6 +882,10 @@ func (s *LlamaServer) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *llmServer) EstimatedVRAM() uint64 {
|
||||||
|
return s.estimatedVRAM
|
||||||
|
}
|
||||||
|
|
||||||
func parseDurationMs(ms float64) time.Duration {
|
func parseDurationMs(ms float64) time.Duration {
|
||||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
202
server/routes.go
202
server/routes.go
|
@ -15,11 +15,8 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -38,7 +35,8 @@ import (
|
||||||
var mode string = gin.DebugMode
|
var mode string = gin.DebugMode
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
|
sched *Scheduler
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -53,88 +51,8 @@ func init() {
|
||||||
gin.SetMode(mode)
|
gin.SetMode(mode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var loaded struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
|
|
||||||
llama *llm.LlamaServer
|
|
||||||
|
|
||||||
expireTimer *time.Timer
|
|
||||||
|
|
||||||
model string
|
|
||||||
adapters []string
|
|
||||||
projectors []string
|
|
||||||
*api.Options
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultSessionDuration = 5 * time.Minute
|
var defaultSessionDuration = 5 * time.Minute
|
||||||
|
|
||||||
func unload() {
|
|
||||||
if loaded.llama != nil {
|
|
||||||
loaded.llama.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
loaded.llama = nil
|
|
||||||
loaded.model = ""
|
|
||||||
loaded.adapters = nil
|
|
||||||
loaded.projectors = nil
|
|
||||||
loaded.Options = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
|
|
||||||
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
|
|
||||||
ctx, cancel := context.WithTimeout(c, 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
needLoad := loaded.llama == nil || // is there a model loaded?
|
|
||||||
loaded.model != model.ModelPath || // has the base model changed?
|
|
||||||
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
|
|
||||||
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
|
|
||||||
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
|
|
||||||
loaded.llama.Ping(ctx) != nil
|
|
||||||
|
|
||||||
if needLoad {
|
|
||||||
if loaded.llama != nil {
|
|
||||||
slog.Info("changing loaded model")
|
|
||||||
unload()
|
|
||||||
}
|
|
||||||
|
|
||||||
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
|
||||||
if err != nil {
|
|
||||||
// some older models are not compatible with newer versions of llama.cpp
|
|
||||||
// show a generalized compatibility error until there is a better way to
|
|
||||||
// check for model compatibility
|
|
||||||
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
|
|
||||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
loaded.model = model.ModelPath
|
|
||||||
loaded.adapters = model.AdapterPaths
|
|
||||||
loaded.projectors = model.ProjectorPaths
|
|
||||||
loaded.llama = llama
|
|
||||||
loaded.Options = &opts
|
|
||||||
|
|
||||||
if err = llama.WaitUntilRunning(); err != nil {
|
|
||||||
slog.Error("error loading llama server", "error", err)
|
|
||||||
unload()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if loaded.expireTimer == nil {
|
|
||||||
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
|
|
||||||
loaded.mu.Lock()
|
|
||||||
defer loaded.mu.Unlock()
|
|
||||||
unload()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
loaded.expireTimer.Reset(sessionDuration)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
|
||||||
opts := api.DefaultOptions()
|
opts := api.DefaultOptions()
|
||||||
if err := opts.FromMap(model.Options); err != nil {
|
if err := opts.FromMap(model.Options); err != nil {
|
||||||
|
@ -154,9 +72,7 @@ func isSupportedImageType(image []byte) bool {
|
||||||
return slices.Contains(allowedTypes, contentType)
|
return slices.Contains(allowedTypes, contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateHandler(c *gin.Context) {
|
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
loaded.mu.Lock()
|
|
||||||
defer loaded.mu.Unlock()
|
|
||||||
|
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
|
@ -224,7 +140,11 @@ func GenerateHandler(c *gin.Context) {
|
||||||
sessionDuration = req.KeepAlive.Duration
|
sessionDuration = req.KeepAlive.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
||||||
|
var runner *runnerRef
|
||||||
|
select {
|
||||||
|
case runner = <-rCh:
|
||||||
|
case err = <-eCh:
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -275,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
|
||||||
|
|
||||||
sb.Reset()
|
sb.Reset()
|
||||||
if req.Context != nil {
|
if req.Context != nil {
|
||||||
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
|
prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -297,9 +217,6 @@ func GenerateHandler(c *gin.Context) {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
|
||||||
fn := func(r llm.CompletionResponse) {
|
fn := func(r llm.CompletionResponse) {
|
||||||
// Update model expiration
|
|
||||||
loaded.expireTimer.Reset(sessionDuration)
|
|
||||||
|
|
||||||
// Build up the full response
|
// Build up the full response
|
||||||
if _, err := generated.WriteString(r.Content); err != nil {
|
if _, err := generated.WriteString(r.Content); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
|
@ -331,7 +248,7 @@ func GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO (jmorganca): encode() should not strip special tokens
|
// TODO (jmorganca): encode() should not strip special tokens
|
||||||
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
|
tokens, err := runner.llama.Tokenize(c.Request.Context(), p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
return
|
return
|
||||||
|
@ -359,7 +276,7 @@ func GenerateHandler(c *gin.Context) {
|
||||||
Images: images,
|
Images: images,
|
||||||
Options: opts,
|
Options: opts,
|
||||||
}
|
}
|
||||||
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -421,10 +338,7 @@ func getDefaultSessionDuration() time.Duration {
|
||||||
return defaultSessionDuration
|
return defaultSessionDuration
|
||||||
}
|
}
|
||||||
|
|
||||||
func EmbeddingsHandler(c *gin.Context) {
|
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
loaded.mu.Lock()
|
|
||||||
defer loaded.mu.Unlock()
|
|
||||||
|
|
||||||
var req api.EmbeddingRequest
|
var req api.EmbeddingRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -469,7 +383,11 @@ func EmbeddingsHandler(c *gin.Context) {
|
||||||
sessionDuration = req.KeepAlive.Duration
|
sessionDuration = req.KeepAlive.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
||||||
|
var runner *runnerRef
|
||||||
|
select {
|
||||||
|
case runner = <-rCh:
|
||||||
|
case err = <-eCh:
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -480,7 +398,7 @@ func EmbeddingsHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
|
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
|
@ -493,7 +411,7 @@ func EmbeddingsHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func PullModelHandler(c *gin.Context) {
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
||||||
var req api.PullRequest
|
var req api.PullRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -542,7 +460,7 @@ func PullModelHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func PushModelHandler(c *gin.Context) {
|
func (s *Server) PushModelHandler(c *gin.Context) {
|
||||||
var req api.PushRequest
|
var req api.PushRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -591,7 +509,7 @@ func PushModelHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateModelHandler(c *gin.Context) {
|
func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||||
var req api.CreateRequest
|
var req api.CreateRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -664,7 +582,7 @@ func CreateModelHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteModelHandler(c *gin.Context) {
|
func (s *Server) DeleteModelHandler(c *gin.Context) {
|
||||||
var req api.DeleteRequest
|
var req api.DeleteRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -709,7 +627,7 @@ func DeleteModelHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, nil)
|
c.JSON(http.StatusOK, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShowModelHandler(c *gin.Context) {
|
func (s *Server) ShowModelHandler(c *gin.Context) {
|
||||||
var req api.ShowRequest
|
var req api.ShowRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -809,7 +727,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModelsHandler(c *gin.Context) {
|
func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||||
models := make([]api.ModelResponse, 0)
|
models := make([]api.ModelResponse, 0)
|
||||||
manifestsPath, err := GetManifestPath()
|
manifestsPath, err := GetManifestPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -869,7 +787,7 @@ func ListModelsHandler(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, api.ListResponse{Models: models})
|
c.JSON(http.StatusOK, api.ListResponse{Models: models})
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyModelHandler(c *gin.Context) {
|
func (s *Server) CopyModelHandler(c *gin.Context) {
|
||||||
var req api.CopyRequest
|
var req api.CopyRequest
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
switch {
|
switch {
|
||||||
|
@ -901,7 +819,7 @@ func CopyModelHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HeadBlobHandler(c *gin.Context) {
|
func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||||
path, err := GetBlobsPath(c.Param("digest"))
|
path, err := GetBlobsPath(c.Param("digest"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
@ -916,7 +834,7 @@ func HeadBlobHandler(c *gin.Context) {
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateBlobHandler(c *gin.Context) {
|
func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||||
path, err := GetBlobsPath(c.Param("digest"))
|
path, err := GetBlobsPath(c.Param("digest"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
@ -1063,27 +981,27 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||||
allowedHostsMiddleware(s.addr),
|
allowedHostsMiddleware(s.addr),
|
||||||
)
|
)
|
||||||
|
|
||||||
r.POST("/api/pull", PullModelHandler)
|
r.POST("/api/pull", s.PullModelHandler)
|
||||||
r.POST("/api/generate", GenerateHandler)
|
r.POST("/api/generate", s.GenerateHandler)
|
||||||
r.POST("/api/chat", ChatHandler)
|
r.POST("/api/chat", s.ChatHandler)
|
||||||
r.POST("/api/embeddings", EmbeddingsHandler)
|
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||||
r.POST("/api/create", CreateModelHandler)
|
r.POST("/api/create", s.CreateModelHandler)
|
||||||
r.POST("/api/push", PushModelHandler)
|
r.POST("/api/push", s.PushModelHandler)
|
||||||
r.POST("/api/copy", CopyModelHandler)
|
r.POST("/api/copy", s.CopyModelHandler)
|
||||||
r.DELETE("/api/delete", DeleteModelHandler)
|
r.DELETE("/api/delete", s.DeleteModelHandler)
|
||||||
r.POST("/api/show", ShowModelHandler)
|
r.POST("/api/show", s.ShowModelHandler)
|
||||||
r.POST("/api/blobs/:digest", CreateBlobHandler)
|
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||||
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
|
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
|
||||||
|
|
||||||
// Compatibility endpoints
|
// Compatibility endpoints
|
||||||
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
|
r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
|
||||||
|
|
||||||
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
||||||
r.Handle(method, "/", func(c *gin.Context) {
|
r.Handle(method, "/", func(c *gin.Context) {
|
||||||
c.String(http.StatusOK, "Ollama is running")
|
c.String(http.StatusOK, "Ollama is running")
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Handle(method, "/api/tags", ListModelsHandler)
|
r.Handle(method, "/api/tags", s.ListModelsHandler)
|
||||||
r.Handle(method, "/api/version", func(c *gin.Context) {
|
r.Handle(method, "/api/version", func(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, gin.H{"version": version.Version})
|
c.JSON(http.StatusOK, gin.H{"version": version.Version})
|
||||||
})
|
})
|
||||||
|
@ -1137,7 +1055,9 @@ func Serve(ln net.Listener) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{addr: ln.Addr()}
|
ctx, done := context.WithCancel(context.Background())
|
||||||
|
sched := InitScheduler(ctx)
|
||||||
|
s := &Server{addr: ln.Addr(), sched: sched}
|
||||||
r := s.GenerateRoutes()
|
r := s.GenerateRoutes()
|
||||||
|
|
||||||
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
||||||
|
@ -1150,7 +1070,8 @@ func Serve(ln net.Listener) error {
|
||||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||||
go func() {
|
go func() {
|
||||||
<-signals
|
<-signals
|
||||||
unload()
|
done()
|
||||||
|
sched.unloadAllRunners()
|
||||||
gpu.Cleanup()
|
gpu.Cleanup()
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}()
|
}()
|
||||||
|
@ -1158,12 +1079,12 @@ func Serve(ln net.Listener) error {
|
||||||
if err := llm.Init(); err != nil {
|
if err := llm.Init(); err != nil {
|
||||||
return fmt.Errorf("unable to initialize llm library %w", err)
|
return fmt.Errorf("unable to initialize llm library %w", err)
|
||||||
}
|
}
|
||||||
if runtime.GOOS == "linux" { // TODO - windows too
|
|
||||||
// check compatibility to log warnings
|
s.sched.Run(ctx)
|
||||||
if _, err := gpu.CheckVRAM(); err != nil {
|
|
||||||
slog.Info(err.Error())
|
// At startup we retrieve GPU information so we can get log messages before loading a model
|
||||||
}
|
// This will log warnings to the log in case we have problems with detected GPUs
|
||||||
}
|
_ = gpu.GetGPUInfo()
|
||||||
|
|
||||||
return srvr.Serve(ln)
|
return srvr.Serve(ln)
|
||||||
}
|
}
|
||||||
|
@ -1219,9 +1140,9 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||||
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
|
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
|
||||||
encode := func(s string) ([]int, error) {
|
encode := func(s string) ([]int, error) {
|
||||||
return loaded.llama.Tokenize(ctx, s)
|
return runner.llama.Tokenize(ctx, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||||
|
@ -1232,10 +1153,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu
|
||||||
return prompt, nil
|
return prompt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ChatHandler(c *gin.Context) {
|
func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
loaded.mu.Lock()
|
|
||||||
defer loaded.mu.Unlock()
|
|
||||||
|
|
||||||
checkpointStart := time.Now()
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
var req api.ChatRequest
|
var req api.ChatRequest
|
||||||
|
@ -1292,7 +1210,11 @@ func ChatHandler(c *gin.Context) {
|
||||||
sessionDuration = req.KeepAlive.Duration
|
sessionDuration = req.KeepAlive.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := load(c, model, opts, sessionDuration); err != nil {
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
||||||
|
var runner *runnerRef
|
||||||
|
select {
|
||||||
|
case runner = <-rCh:
|
||||||
|
case err = <-eCh:
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1309,7 +1231,7 @@ func ChatHandler(c *gin.Context) {
|
||||||
}, req.Messages...)
|
}, req.Messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
|
prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -1352,8 +1274,6 @@ func ChatHandler(c *gin.Context) {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
|
||||||
fn := func(r llm.CompletionResponse) {
|
fn := func(r llm.CompletionResponse) {
|
||||||
// Update model expiration
|
|
||||||
loaded.expireTimer.Reset(sessionDuration)
|
|
||||||
|
|
||||||
resp := api.ChatResponse{
|
resp := api.ChatResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
|
@ -1376,7 +1296,7 @@ func ChatHandler(c *gin.Context) {
|
||||||
ch <- resp
|
ch <- resp
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Format: req.Format,
|
Format: req.Format,
|
||||||
Images: images,
|
Images: images,
|
||||||
|
|
525
server/sched.go
Normal file
525
server/sched.go
Normal file
|
@ -0,0 +1,525 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LlmRequest struct {
|
||||||
|
ctx context.Context //nolint:containedctx
|
||||||
|
model *Model
|
||||||
|
ggml *llm.GGML // TODO - how large is this, and do we need to free it after we've finished loading?
|
||||||
|
opts api.Options
|
||||||
|
sessionDuration time.Duration
|
||||||
|
successCh chan *runnerRef
|
||||||
|
errCh chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Scheduler struct {
|
||||||
|
pendingReqCh chan *LlmRequest
|
||||||
|
finishedReqCh chan *LlmRequest
|
||||||
|
expiredCh chan *runnerRef
|
||||||
|
unloadedCh chan interface{}
|
||||||
|
|
||||||
|
loaded map[string]*runnerRef
|
||||||
|
loadedMu sync.Mutex
|
||||||
|
|
||||||
|
loadFn func(req *LlmRequest, gpus gpu.GpuInfoList)
|
||||||
|
newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error)
|
||||||
|
getGpuFn func() gpu.GpuInfoList
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO set this to zero after a release or two, to enable multiple models by default
|
||||||
|
var loadedMax = 1 // Maximum runners; < 1 maps to as many as will fit in VRAM (unlimited for CPU runners)
|
||||||
|
var maxQueuedRequests = 10 // TODO configurable
|
||||||
|
|
||||||
|
func InitScheduler(ctx context.Context) *Scheduler {
|
||||||
|
maxRunners := os.Getenv("OLLAMA_MAX_LOADED_MODELS")
|
||||||
|
if maxRunners != "" {
|
||||||
|
m, err := strconv.Atoi(maxRunners)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
|
||||||
|
} else {
|
||||||
|
loadedMax = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sched := &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, maxQueuedRequests),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, maxQueuedRequests),
|
||||||
|
expiredCh: make(chan *runnerRef, maxQueuedRequests),
|
||||||
|
unloadedCh: make(chan interface{}, maxQueuedRequests),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: llm.NewLlamaServer,
|
||||||
|
getGpuFn: gpu.GetGPUInfo,
|
||||||
|
}
|
||||||
|
sched.loadFn = sched.load
|
||||||
|
return sched
|
||||||
|
}
|
||||||
|
|
||||||
|
// context must be canceled to decrement ref count and release the runner
|
||||||
|
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
|
||||||
|
ggml, err := llm.LoadModel(model.ModelPath)
|
||||||
|
req := &LlmRequest{
|
||||||
|
ctx: c,
|
||||||
|
model: model,
|
||||||
|
ggml: ggml,
|
||||||
|
opts: opts,
|
||||||
|
sessionDuration: sessionDuration,
|
||||||
|
successCh: make(chan *runnerRef),
|
||||||
|
errCh: make(chan error, 1),
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
req.errCh <- err
|
||||||
|
return req.successCh, req.errCh
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case s.pendingReqCh <- req:
|
||||||
|
default:
|
||||||
|
req.errCh <- fmt.Errorf("server busy, please try again. maximum pending requests exceeded")
|
||||||
|
}
|
||||||
|
return req.successCh, req.errCh
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns immediately, spawns go routines for the scheduler which will shutdown when ctx is done
|
||||||
|
func (s *Scheduler) Run(ctx context.Context) {
|
||||||
|
slog.Debug("starting llm scheduler")
|
||||||
|
go func() {
|
||||||
|
s.processPending(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s.processCompleted(ctx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scheduler) processPending(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
slog.Debug("shutting down scheduler pending loop")
|
||||||
|
return
|
||||||
|
case pending := <-s.pendingReqCh:
|
||||||
|
// Block other requests until we get this pending request running
|
||||||
|
for {
|
||||||
|
var runnerToExpire *runnerRef
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
runner := s.loaded[pending.model.ModelPath]
|
||||||
|
loadedCount := len(s.loaded)
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
if runner != nil {
|
||||||
|
if runner.needsReload(ctx, pending) {
|
||||||
|
runnerToExpire = runner
|
||||||
|
} else {
|
||||||
|
// Runner is usable, return it
|
||||||
|
pending.useLoadedRunner(runner, s.finishedReqCh)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else if loadedCount == 0 {
|
||||||
|
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
||||||
|
gpus := s.getGpuFn()
|
||||||
|
g := pickBestFitGPUs(pending, gpus)
|
||||||
|
if g != nil {
|
||||||
|
gpus = g
|
||||||
|
}
|
||||||
|
s.loadFn(pending, gpus)
|
||||||
|
break
|
||||||
|
} else if loadedMax > 0 && loadedCount >= loadedMax {
|
||||||
|
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
|
||||||
|
runnerToExpire = s.findRunnerToUnload(pending)
|
||||||
|
} else {
|
||||||
|
// More than one loaded model, so we have to see if the new one fits
|
||||||
|
// Get a refreshed GPU list
|
||||||
|
gpus := s.getGpuFn()
|
||||||
|
// Update free memory from currently loaded models
|
||||||
|
s.updateFreeSpace(gpus)
|
||||||
|
gpus = pickBestFitGPUs(pending, gpus)
|
||||||
|
if gpus != nil {
|
||||||
|
slog.Debug("new model fits with existing models, loading")
|
||||||
|
s.loadFn(pending, gpus)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
runnerToExpire = s.findRunnerToUnload(pending)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runnerToExpire == nil {
|
||||||
|
// Shouildn't happen
|
||||||
|
slog.Error("runner to expire was nil!")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Trigger an expiration to unload once it's done
|
||||||
|
runnerToExpire.refMu.Lock()
|
||||||
|
slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount)
|
||||||
|
if runnerToExpire.expireTimer != nil {
|
||||||
|
runnerToExpire.expireTimer.Stop()
|
||||||
|
runnerToExpire.expireTimer = nil
|
||||||
|
}
|
||||||
|
runnerToExpire.sessionDuration = 0
|
||||||
|
if runnerToExpire.refCount <= 0 {
|
||||||
|
s.expiredCh <- runnerToExpire
|
||||||
|
}
|
||||||
|
runnerToExpire.refMu.Unlock()
|
||||||
|
// Wait for the unload to happen
|
||||||
|
// Note: at this point we're queueing up all incoming requests, even if they were for
|
||||||
|
// a different model that's loaded and not scheduled to be removed.
|
||||||
|
slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
slog.Debug("shutting down scheduler pending loop")
|
||||||
|
return
|
||||||
|
case <-s.unloadedCh:
|
||||||
|
slog.Debug("unload completed", "model", runnerToExpire.model)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-s.unloadedCh:
|
||||||
|
// An unload request when there are no pending request can be ignored
|
||||||
|
slog.Debug("ignoring unload event with no pending requests")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||||
|
// Process completed requests, expired timers, and unloading models
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
slog.Debug("shutting down scheduler completed loop")
|
||||||
|
return
|
||||||
|
case finished := <-s.finishedReqCh:
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
runner := s.loaded[finished.model.ModelPath]
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
if runner == nil {
|
||||||
|
slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
runner.refMu.Lock()
|
||||||
|
runner.refCount--
|
||||||
|
if runner.refCount <= 0 {
|
||||||
|
if runner.sessionDuration <= 0 {
|
||||||
|
slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model)
|
||||||
|
if runner.expireTimer != nil {
|
||||||
|
runner.expireTimer.Stop()
|
||||||
|
runner.expireTimer = nil
|
||||||
|
}
|
||||||
|
s.expiredCh <- runner
|
||||||
|
} else if runner.expireTimer == nil {
|
||||||
|
slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration)
|
||||||
|
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
|
||||||
|
slog.Debug("timer expired, expiring to unload", "model", runner.model)
|
||||||
|
runner.refMu.Lock()
|
||||||
|
defer runner.refMu.Unlock()
|
||||||
|
if runner.expireTimer != nil {
|
||||||
|
runner.expireTimer.Stop()
|
||||||
|
}
|
||||||
|
s.expiredCh <- runner
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration)
|
||||||
|
runner.expireTimer.Reset(runner.sessionDuration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount)
|
||||||
|
runner.refMu.Unlock()
|
||||||
|
case runner := <-s.expiredCh:
|
||||||
|
slog.Debug("runner expired event received", "model", runner.model)
|
||||||
|
runner.refMu.Lock()
|
||||||
|
if runner.refCount > 0 {
|
||||||
|
// Shouldn't happen, but safeguard to ensure no leaked runners
|
||||||
|
slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount)
|
||||||
|
go func(runner *runnerRef) {
|
||||||
|
// We can't unload yet, but want to as soon as the current request completes
|
||||||
|
// So queue up another expired event
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
s.expiredCh <- runner
|
||||||
|
}(runner)
|
||||||
|
runner.refMu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("got lock to unload", "model", runner.model)
|
||||||
|
runner.unload()
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
delete(s.loaded, runner.model)
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
slog.Debug("runner released", "model", runner.model)
|
||||||
|
runner.refMu.Unlock()
|
||||||
|
slog.Debug("sending an unloaded event", "model", runner.model)
|
||||||
|
s.unloadedCh <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the pending request and send the runner back to the requester
|
||||||
|
// Wires up a finished event after the request context is completed
|
||||||
|
// Updates session duration, and resets expiration timer
|
||||||
|
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
|
||||||
|
runner.refMu.Lock()
|
||||||
|
defer runner.refMu.Unlock()
|
||||||
|
runner.refCount++
|
||||||
|
runner.sessionDuration = pending.sessionDuration
|
||||||
|
pending.successCh <- runner
|
||||||
|
go func() {
|
||||||
|
<-pending.ctx.Done()
|
||||||
|
slog.Debug("context for request finished")
|
||||||
|
finished <- pending
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scheduler) load(req *LlmRequest, gpus gpu.GpuInfoList) {
|
||||||
|
llama, err := s.newServerFn(gpus, req.model.ModelPath, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
|
||||||
|
if err != nil {
|
||||||
|
// some older models are not compatible with newer versions of llama.cpp
|
||||||
|
// show a generalized compatibility error until there is a better way to
|
||||||
|
// check for model compatibility
|
||||||
|
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
|
||||||
|
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
|
||||||
|
}
|
||||||
|
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
|
||||||
|
req.errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runner := &runnerRef{}
|
||||||
|
runner.model = req.model.ModelPath
|
||||||
|
runner.adapters = req.model.AdapterPaths
|
||||||
|
runner.projectors = req.model.ProjectorPaths
|
||||||
|
runner.llama = llama
|
||||||
|
runner.Options = &req.opts
|
||||||
|
runner.sessionDuration = req.sessionDuration
|
||||||
|
runner.gpus = gpus
|
||||||
|
runner.estimatedVRAM = llama.EstimatedVRAM()
|
||||||
|
runner.loading = true
|
||||||
|
runner.refCount = 1
|
||||||
|
runner.refMu.Lock()
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
s.loaded[req.model.ModelPath] = runner
|
||||||
|
slog.Info("loaded runners", "count", len(s.loaded))
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer runner.refMu.Unlock()
|
||||||
|
if err = llama.WaitUntilRunning(req.ctx); err != nil {
|
||||||
|
slog.Error("error loading llama server", "error", err)
|
||||||
|
runner.refCount--
|
||||||
|
req.errCh <- err
|
||||||
|
slog.Debug("triggering expiration for failed load", "model", runner.model)
|
||||||
|
s.expiredCh <- runner
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Debug("finished setting up runner", "model", req.model.ModelPath)
|
||||||
|
runner.loading = false
|
||||||
|
go func() {
|
||||||
|
<-req.ctx.Done()
|
||||||
|
slog.Debug("context for request finished")
|
||||||
|
s.finishedReqCh <- req
|
||||||
|
}()
|
||||||
|
req.successCh <- runner
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
|
||||||
|
type predKey struct {
|
||||||
|
Library string
|
||||||
|
ID string
|
||||||
|
}
|
||||||
|
predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
for _, r := range s.loaded {
|
||||||
|
r.refMu.Lock()
|
||||||
|
gpuIDs := make([]string, 0, len(r.gpus))
|
||||||
|
if r.llama != nil {
|
||||||
|
|
||||||
|
// TODO this should be broken down by GPU instead of assuming uniform spread
|
||||||
|
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
|
||||||
|
for _, gpu := range r.gpus {
|
||||||
|
gpuIDs = append(gpuIDs, gpu.ID)
|
||||||
|
}
|
||||||
|
for _, gpu := range allGpus {
|
||||||
|
if slices.Contains(gpuIDs, gpu.ID) {
|
||||||
|
predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
|
||||||
|
}
|
||||||
|
r.refMu.Unlock()
|
||||||
|
}
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
|
// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
|
||||||
|
for i := range allGpus {
|
||||||
|
if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok {
|
||||||
|
slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory))
|
||||||
|
if p > allGpus[i].TotalMemory {
|
||||||
|
// Shouldn't happen
|
||||||
|
slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p)
|
||||||
|
allGpus[i].FreeMemory = 0
|
||||||
|
} else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // predicted free is smaller than reported free, use it
|
||||||
|
// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
|
||||||
|
// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
|
||||||
|
// after we start our first runner, then we'll never acount for that, so picking the smallest free value seems prudent.
|
||||||
|
allGpus[i].FreeMemory = allGpus[i].TotalMemory - p
|
||||||
|
}
|
||||||
|
slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type runnerRef struct {
|
||||||
|
refMu sync.Mutex
|
||||||
|
// refCond sync.Cond // Signaled on transition from 1 -> 0 refCount
|
||||||
|
refCount uint // prevent unloading if > 0
|
||||||
|
// unloading bool // set to true when we are trying to unload the runner
|
||||||
|
|
||||||
|
llama llm.LlamaServer
|
||||||
|
loading bool // True only during initial load, then false forever
|
||||||
|
gpus gpu.GpuInfoList // Recorded at time of provisioning
|
||||||
|
estimatedVRAM uint64
|
||||||
|
|
||||||
|
sessionDuration time.Duration
|
||||||
|
expireTimer *time.Timer
|
||||||
|
|
||||||
|
model string
|
||||||
|
adapters []string
|
||||||
|
projectors []string
|
||||||
|
*api.Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// The refMu must already be held when calling unload
|
||||||
|
func (runner *runnerRef) unload() {
|
||||||
|
if runner.llama != nil {
|
||||||
|
runner.llama.Close()
|
||||||
|
}
|
||||||
|
runner.llama = nil
|
||||||
|
runner.adapters = nil
|
||||||
|
runner.projectors = nil
|
||||||
|
runner.Options = nil
|
||||||
|
runner.gpus = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
|
||||||
|
slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
|
||||||
|
runner.refMu.Lock()
|
||||||
|
defer runner.refMu.Unlock()
|
||||||
|
// Ignore the NumGPU settings for comparison
|
||||||
|
optsExisting := runner.Options.Runner
|
||||||
|
optsExisting.NumGPU = -1
|
||||||
|
optsNew := req.opts.Runner
|
||||||
|
optsNew.NumGPU = -1
|
||||||
|
timeout := 10 * time.Second
|
||||||
|
if runner.loading {
|
||||||
|
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout) // BUG -
|
||||||
|
defer cancel()
|
||||||
|
if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed?
|
||||||
|
!reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed?
|
||||||
|
!reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
|
||||||
|
runner.llama.Ping(ctx) != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type ByDuration []*runnerRef
|
||||||
|
|
||||||
|
func (a ByDuration) Len() int { return len(a) }
|
||||||
|
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
|
func (a ByDuration) Less(i, j int) bool {
|
||||||
|
// uint64 to turn negative time (never unload) to largest
|
||||||
|
return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - future consideration to pick runners based on size
|
||||||
|
// type BySize []*runnerRef
|
||||||
|
// func (a BySize) Len() int { return len(a) }
|
||||||
|
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
|
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
|
||||||
|
|
||||||
|
// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits
|
||||||
|
// If the model can not be fit fully within the available GPU(s) nil is returned
|
||||||
|
func pickBestFitGPUs(req *LlmRequest, gpus gpu.GpuInfoList) gpu.GpuInfoList {
|
||||||
|
var estimatedVRAM uint64
|
||||||
|
for _, gl := range gpus.ByLibrary() {
|
||||||
|
var ok bool
|
||||||
|
sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...)
|
||||||
|
|
||||||
|
// TODO - potentially sort by performance capability, existing models loaded, etc.
|
||||||
|
// Note: at present, this will favor more VRAM over faster GPU speed in mixed setups
|
||||||
|
sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl)))
|
||||||
|
|
||||||
|
// First attempt to fit the model into a single GPU
|
||||||
|
for _, g := range sgl {
|
||||||
|
if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||||
|
slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
|
return []gpu.GpuInfo{g}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO future refinements
|
||||||
|
// - if multiple Libraries, see if any single GPU in any Library will fit
|
||||||
|
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
||||||
|
|
||||||
|
// Now try all the GPUs
|
||||||
|
if ok, estimatedVRAM = llm.PredictServerFit(gl, req.ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
|
||||||
|
slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM))
|
||||||
|
return gl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findRunnerToUnload finds a runner to unload to make room for a new model
|
||||||
|
func (s *Scheduler) findRunnerToUnload(req *LlmRequest) *runnerRef {
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
runnerList := make([]*runnerRef, 0, len(s.loaded))
|
||||||
|
for _, r := range s.loaded {
|
||||||
|
runnerList = append(runnerList, r)
|
||||||
|
}
|
||||||
|
s.loadedMu.Unlock()
|
||||||
|
|
||||||
|
// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
|
||||||
|
// e.g., if we have multiple options, will one make room for the request?
|
||||||
|
sort.Sort(ByDuration(runnerList))
|
||||||
|
|
||||||
|
// First try to find a runner that's already idle
|
||||||
|
for _, runner := range runnerList {
|
||||||
|
runner.refMu.Lock()
|
||||||
|
rc := runner.refCount
|
||||||
|
runner.refMu.Unlock()
|
||||||
|
if rc == 0 {
|
||||||
|
slog.Debug("found an idle runner to unload")
|
||||||
|
return runner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// None appear idle, just wait for the one with the shortest duration
|
||||||
|
slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList))
|
||||||
|
return runnerList[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Scheduler) unloadAllRunners() {
|
||||||
|
s.loadedMu.Lock()
|
||||||
|
defer s.loadedMu.Unlock()
|
||||||
|
for model, runner := range s.loaded {
|
||||||
|
if runner.llama != nil {
|
||||||
|
slog.Debug("shutting down runner", "model", model)
|
||||||
|
runner.llama.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
553
server/sched_test.go
Normal file
553
server/sched_test.go
Normal file
|
@ -0,0 +1,553 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
os.Setenv("OLLAMA_DEBUG", "1")
|
||||||
|
lifecycle.InitLogging()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInitScheduler(t *testing.T) {
|
||||||
|
ctx, done := context.WithCancel(context.Background())
|
||||||
|
defer done()
|
||||||
|
initialMax := loadedMax
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
require.Equal(t, initialMax, loadedMax)
|
||||||
|
require.NotNil(t, s.loaded)
|
||||||
|
|
||||||
|
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
|
||||||
|
s = InitScheduler(ctx)
|
||||||
|
require.Equal(t, initialMax, loadedMax)
|
||||||
|
require.NotNil(t, s.loaded)
|
||||||
|
|
||||||
|
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
|
||||||
|
s = InitScheduler(ctx)
|
||||||
|
require.Equal(t, 0, loadedMax)
|
||||||
|
require.NotNil(t, s.loaded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
req := &LlmRequest{
|
||||||
|
ctx: ctx,
|
||||||
|
model: &Model{ModelPath: "foo"},
|
||||||
|
successCh: make(chan *runnerRef, 1),
|
||||||
|
errCh: make(chan error, 1),
|
||||||
|
sessionDuration: 2,
|
||||||
|
}
|
||||||
|
// Fail to load model first
|
||||||
|
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
||||||
|
return nil, fmt.Errorf("something failed to load model blah")
|
||||||
|
}
|
||||||
|
gpus := gpu.GpuInfoList{}
|
||||||
|
s.load(req, gpus)
|
||||||
|
require.Len(t, req.successCh, 0)
|
||||||
|
require.Len(t, req.errCh, 1)
|
||||||
|
require.Len(t, s.loaded, 0)
|
||||||
|
err := <-req.errCh
|
||||||
|
require.Contains(t, err.Error(), "this model may be incompatible")
|
||||||
|
|
||||||
|
server := &mockLlm{estimatedVRAM: 10}
|
||||||
|
s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
||||||
|
return server, nil
|
||||||
|
}
|
||||||
|
s.load(req, gpus)
|
||||||
|
select {
|
||||||
|
case err := <-req.errCh:
|
||||||
|
require.NoError(t, err)
|
||||||
|
case resp := <-req.successCh:
|
||||||
|
require.Equal(t, uint64(10), resp.estimatedVRAM)
|
||||||
|
require.Equal(t, uint(1), resp.refCount)
|
||||||
|
require.Len(t, s.loaded, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.model.ModelPath = "dummy_model_path"
|
||||||
|
server.waitResp = fmt.Errorf("wait failure")
|
||||||
|
s.load(req, gpus)
|
||||||
|
select {
|
||||||
|
case err := <-req.errCh:
|
||||||
|
require.Contains(t, err.Error(), "wait failure")
|
||||||
|
case resp := <-req.successCh:
|
||||||
|
t.Errorf("unexpected success %v", resp)
|
||||||
|
}
|
||||||
|
runner := s.loaded["dummy_model_path"]
|
||||||
|
require.NotNil(t, runner)
|
||||||
|
require.Equal(t, uint(0), runner.refCount)
|
||||||
|
require.Len(t, s.expiredCh, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
type bundle struct {
|
||||||
|
ctx context.Context //nolint:containedctx
|
||||||
|
ctxDone func()
|
||||||
|
srv *mockLlm
|
||||||
|
req *LlmRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) {
|
||||||
|
return scenario.srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
|
||||||
|
scenario := &bundle{}
|
||||||
|
scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
gguf := llm.NewGGUFV3(binary.LittleEndian)
|
||||||
|
err = gguf.Encode(f, llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"general.name": "name",
|
||||||
|
"llama.context_length": uint32(32),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(32),
|
||||||
|
"tokenizer.ggml.tokens": []string{" "},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []llm.Tensor{
|
||||||
|
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||||
|
})
|
||||||
|
assert.Nil(t, err)
|
||||||
|
fname := f.Name()
|
||||||
|
model := &Model{Name: modelName, ModelPath: fname}
|
||||||
|
ggml, err := llm.LoadModel(model.ModelPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
scenario.req = &LlmRequest{
|
||||||
|
ctx: scenario.ctx,
|
||||||
|
model: model,
|
||||||
|
ggml: ggml,
|
||||||
|
sessionDuration: 5 * time.Millisecond,
|
||||||
|
successCh: make(chan *runnerRef, 1),
|
||||||
|
errCh: make(chan error, 1),
|
||||||
|
}
|
||||||
|
scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM}
|
||||||
|
return scenario
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequests(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
// Same model, same request
|
||||||
|
scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
|
||||||
|
scenario1a.req.sessionDuration = 0
|
||||||
|
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
|
||||||
|
scenario1b.req.model = scenario1a.req.model
|
||||||
|
scenario1b.req.ggml = scenario1a.req.ggml
|
||||||
|
scenario1b.req.sessionDuration = 0
|
||||||
|
|
||||||
|
// simple reload of same model
|
||||||
|
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
|
||||||
|
scenario2a.req.model = scenario1a.req.model
|
||||||
|
scenario2a.req.ggml = scenario1a.req.ggml
|
||||||
|
|
||||||
|
// Multiple loaded models
|
||||||
|
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
|
||||||
|
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
|
||||||
|
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
|
||||||
|
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
|
g.TotalMemory = 24 * format.GigaByte
|
||||||
|
g.FreeMemory = 12 * format.GigaByte
|
||||||
|
return []gpu.GpuInfo{g}
|
||||||
|
}
|
||||||
|
s.newServerFn = scenario1a.newServer
|
||||||
|
slog.Info("scenario1a")
|
||||||
|
s.pendingReqCh <- scenario1a.req
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
s.Run(ctx)
|
||||||
|
select {
|
||||||
|
case resp := <-scenario1a.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, scenario1a.req.errCh, 0)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same runner as first request due to not needing a reload
|
||||||
|
s.newServerFn = scenario1b.newServer
|
||||||
|
slog.Info("scenario1b")
|
||||||
|
s.pendingReqCh <- scenario1b.req
|
||||||
|
select {
|
||||||
|
case resp := <-scenario1b.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, scenario1b.req.errCh, 0)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger a reload
|
||||||
|
s.newServerFn = scenario2a.newServer
|
||||||
|
scenario2a.req.model.AdapterPaths = []string{"new"}
|
||||||
|
slog.Info("scenario2a")
|
||||||
|
s.pendingReqCh <- scenario2a.req
|
||||||
|
// finish first two requests, so model can reload
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
scenario1a.ctxDone()
|
||||||
|
scenario1b.ctxDone()
|
||||||
|
select {
|
||||||
|
case resp := <-scenario2a.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, scenario2a.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, scenario2a.req.errCh, 0)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
loadedMax = 1
|
||||||
|
s.newServerFn = scenario3a.newServer
|
||||||
|
slog.Info("scenario3a")
|
||||||
|
s.pendingReqCh <- scenario3a.req
|
||||||
|
// finish prior request, so new model can load
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
scenario2a.ctxDone()
|
||||||
|
select {
|
||||||
|
case resp := <-scenario3a.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, scenario3a.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, scenario3a.req.errCh, 0)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
require.Len(t, s.loaded, 1)
|
||||||
|
|
||||||
|
loadedMax = 0
|
||||||
|
s.newServerFn = scenario3b.newServer
|
||||||
|
slog.Info("scenario3b")
|
||||||
|
s.pendingReqCh <- scenario3b.req
|
||||||
|
select {
|
||||||
|
case resp := <-scenario3b.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, scenario3b.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, scenario3b.req.errCh, 0)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
require.Len(t, s.loaded, 2)
|
||||||
|
|
||||||
|
// Try to load a model that wont fit
|
||||||
|
s.newServerFn = scenario3c.newServer
|
||||||
|
slog.Info("scenario3c")
|
||||||
|
require.Len(t, s.loaded, 2)
|
||||||
|
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
s.pendingReqCh <- scenario3c.req
|
||||||
|
// finish prior request, so new model can load
|
||||||
|
time.Sleep(6 * time.Millisecond)
|
||||||
|
require.Len(t, s.loaded, 1)
|
||||||
|
scenario3b.ctxDone()
|
||||||
|
select {
|
||||||
|
case resp := <-scenario3c.req.successCh:
|
||||||
|
require.Equal(t, resp.llama, scenario3c.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, scenario3c.req.errCh, 0)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
require.Len(t, s.loaded, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRunner(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
// Same model, same request
|
||||||
|
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
||||||
|
scenario1a.req.sessionDuration = 0
|
||||||
|
scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
|
||||||
|
scenario1b.req.sessionDuration = 0
|
||||||
|
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
|
||||||
|
scenario1c.req.sessionDuration = 0
|
||||||
|
maxQueuedRequests = 1
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
|
g.TotalMemory = 24 * format.GigaByte
|
||||||
|
g.FreeMemory = 12 * format.GigaByte
|
||||||
|
return []gpu.GpuInfo{g}
|
||||||
|
}
|
||||||
|
s.newServerFn = scenario1a.newServer
|
||||||
|
slog.Info("scenario1a")
|
||||||
|
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
slog.Info("scenario1b")
|
||||||
|
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
require.Len(t, successCh1b, 0)
|
||||||
|
require.Len(t, errCh1b, 1)
|
||||||
|
err := <-errCh1b
|
||||||
|
require.Contains(t, err.Error(), "server busy")
|
||||||
|
s.Run(ctx)
|
||||||
|
select {
|
||||||
|
case resp := <-successCh1a:
|
||||||
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, errCh1a, 0)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
scenario1a.ctxDone()
|
||||||
|
require.Len(t, s.loaded, 1)
|
||||||
|
|
||||||
|
scenario1c.req.model.ModelPath = "bad path"
|
||||||
|
slog.Info("scenario1c")
|
||||||
|
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, successCh1c, 0)
|
||||||
|
require.Len(t, errCh1c, 1)
|
||||||
|
err = <-errCh1c
|
||||||
|
require.Contains(t, err.Error(), "bad path")
|
||||||
|
scenario1b.ctxDone()
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
require.Len(t, s.loaded, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - add one scenario that triggers the bogus finished event with positive ref count
|
||||||
|
func TestPrematureExpired(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
// Same model, same request
|
||||||
|
scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.getGpuFn = func() gpu.GpuInfoList {
|
||||||
|
g := gpu.GpuInfo{Library: "metal"}
|
||||||
|
g.TotalMemory = 24 * format.GigaByte
|
||||||
|
g.FreeMemory = 12 * format.GigaByte
|
||||||
|
return []gpu.GpuInfo{g}
|
||||||
|
}
|
||||||
|
s.newServerFn = scenario1a.newServer
|
||||||
|
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
|
||||||
|
require.Len(t, s.pendingReqCh, 1)
|
||||||
|
s.Run(ctx)
|
||||||
|
select {
|
||||||
|
case resp := <-successCh1a:
|
||||||
|
require.Equal(t, resp.llama, scenario1a.srv)
|
||||||
|
require.Len(t, s.pendingReqCh, 0)
|
||||||
|
require.Len(t, errCh1a, 0)
|
||||||
|
require.Len(t, s.loaded, 1)
|
||||||
|
slog.Info("sending premature expired event now")
|
||||||
|
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
time.Sleep(scenario1a.req.sessionDuration)
|
||||||
|
scenario1a.ctxDone()
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
require.LessOrEqual(t, len(s.finishedReqCh), 1)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
require.Len(t, s.finishedReqCh, 0)
|
||||||
|
require.Len(t, s.loaded, 0)
|
||||||
|
|
||||||
|
// also shouldn't happen in real life
|
||||||
|
s.finishedReqCh <- scenario1a.req
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUseLoadedRunner(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
|
req := &LlmRequest{
|
||||||
|
ctx: ctx,
|
||||||
|
successCh: make(chan *runnerRef, 1),
|
||||||
|
sessionDuration: 2,
|
||||||
|
}
|
||||||
|
finished := make(chan *LlmRequest)
|
||||||
|
llm1 := &mockLlm{}
|
||||||
|
r1 := &runnerRef{llama: llm1, sessionDuration: 1}
|
||||||
|
req.useLoadedRunner(r1, finished)
|
||||||
|
require.Equal(t, uint(1), r1.refCount)
|
||||||
|
require.Equal(t, time.Duration(2), r1.sessionDuration)
|
||||||
|
select {
|
||||||
|
case success := <-req.successCh:
|
||||||
|
require.Equal(t, r1, success)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Errorf("timeout")
|
||||||
|
}
|
||||||
|
done()
|
||||||
|
fin := <-finished
|
||||||
|
require.Equal(t, req, fin)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateFreeSpace(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
gpus := gpu.GpuInfoList{
|
||||||
|
{
|
||||||
|
Library: "a",
|
||||||
|
ID: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Library: "a",
|
||||||
|
ID: "2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
gpus[0].TotalMemory = 1000
|
||||||
|
gpus[0].FreeMemory = 900
|
||||||
|
gpus[1].TotalMemory = 2000
|
||||||
|
gpus[1].FreeMemory = 1900
|
||||||
|
llm1 := &mockLlm{estimatedVRAM: 100}
|
||||||
|
llm2 := &mockLlm{estimatedVRAM: 200}
|
||||||
|
r1 := &runnerRef{llama: llm1, gpus: gpus}
|
||||||
|
r2 := &runnerRef{llama: llm2, gpus: gpus}
|
||||||
|
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.loaded["a"] = r1
|
||||||
|
s.loaded["b"] = r2
|
||||||
|
|
||||||
|
s.updateFreeSpace(gpus)
|
||||||
|
require.Equal(t, uint64(850), gpus[0].FreeMemory)
|
||||||
|
require.Equal(t, uint64(1850), gpus[1].FreeMemory)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindRunnerToUnload(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
req := &LlmRequest{ctx: ctx}
|
||||||
|
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
|
||||||
|
r2 := &runnerRef{sessionDuration: 2}
|
||||||
|
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.loaded["a"] = r1
|
||||||
|
s.loaded["b"] = r2
|
||||||
|
|
||||||
|
resp := s.findRunnerToUnload(req)
|
||||||
|
require.Equal(t, r2, resp)
|
||||||
|
r2.refCount = 1
|
||||||
|
resp = s.findRunnerToUnload(req)
|
||||||
|
require.Equal(t, r1, resp)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNeedsReload(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
llm := &mockLlm{}
|
||||||
|
runner := &runnerRef{
|
||||||
|
adapters: []string{"adapter1"},
|
||||||
|
projectors: []string{"projector1"},
|
||||||
|
Options: &api.Options{},
|
||||||
|
llama: llm,
|
||||||
|
}
|
||||||
|
req := &LlmRequest{
|
||||||
|
model: &Model{
|
||||||
|
AdapterPaths: []string{"adapter2"},
|
||||||
|
ProjectorPaths: []string{"projector2"},
|
||||||
|
},
|
||||||
|
opts: api.Options{},
|
||||||
|
}
|
||||||
|
resp := runner.needsReload(ctx, req)
|
||||||
|
require.True(t, resp)
|
||||||
|
req.model.AdapterPaths = runner.adapters
|
||||||
|
resp = runner.needsReload(ctx, req)
|
||||||
|
require.True(t, resp)
|
||||||
|
req.model.ProjectorPaths = runner.projectors
|
||||||
|
runner.loading = true
|
||||||
|
req.opts.NumBatch = 1234
|
||||||
|
resp = runner.needsReload(ctx, req)
|
||||||
|
require.True(t, resp)
|
||||||
|
req.opts.NumBatch = runner.Options.NumBatch
|
||||||
|
llm.pingResp = fmt.Errorf("foo")
|
||||||
|
resp = runner.needsReload(ctx, req)
|
||||||
|
require.True(t, resp)
|
||||||
|
llm.pingResp = nil
|
||||||
|
resp = runner.needsReload(ctx, req)
|
||||||
|
require.False(t, resp)
|
||||||
|
req.opts.NumGPU = 99
|
||||||
|
resp = runner.needsReload(ctx, req)
|
||||||
|
require.False(t, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnloadAllRunners(t *testing.T) {
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
llm1 := &mockLlm{}
|
||||||
|
llm2 := &mockLlm{}
|
||||||
|
s := InitScheduler(ctx)
|
||||||
|
s.unloadAllRunners()
|
||||||
|
|
||||||
|
r1 := &runnerRef{llama: llm1}
|
||||||
|
r2 := &runnerRef{llama: llm2}
|
||||||
|
|
||||||
|
s.loaded["a"] = r1
|
||||||
|
s.loaded["b"] = r2
|
||||||
|
s.unloadAllRunners()
|
||||||
|
|
||||||
|
require.True(t, llm1.closeCalled)
|
||||||
|
require.True(t, llm2.closeCalled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnload(t *testing.T) {
|
||||||
|
llm1 := &mockLlm{}
|
||||||
|
r1 := &runnerRef{llama: llm1}
|
||||||
|
r2 := &runnerRef{adapters: []string{"A"}}
|
||||||
|
r1.unload()
|
||||||
|
require.True(t, llm1.closeCalled)
|
||||||
|
r2.unload()
|
||||||
|
require.Nil(t, r2.adapters)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockLlm struct {
|
||||||
|
pingResp error
|
||||||
|
waitResp error
|
||||||
|
completionResp error
|
||||||
|
embeddingResp []float64
|
||||||
|
embeddingRespErr error
|
||||||
|
tokenizeResp []int
|
||||||
|
tokenizeRespErr error
|
||||||
|
detokenizeResp string
|
||||||
|
detonekizeRespErr error
|
||||||
|
closeResp error
|
||||||
|
closeCalled bool
|
||||||
|
estimatedVRAM uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
|
||||||
|
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
|
||||||
|
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
|
return s.completionResp
|
||||||
|
}
|
||||||
|
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||||
|
return s.embeddingResp, s.embeddingRespErr
|
||||||
|
}
|
||||||
|
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
|
return s.tokenizeResp, s.tokenizeRespErr
|
||||||
|
}
|
||||||
|
func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||||
|
return s.detokenizeResp, s.detonekizeRespErr
|
||||||
|
}
|
||||||
|
func (s *mockLlm) Close() error {
|
||||||
|
s.closeCalled = true
|
||||||
|
return s.closeResp
|
||||||
|
}
|
||||||
|
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
|
Loading…
Reference in a new issue