Merge remote-tracking branch 'upstream/main' into pr3702
This commit is contained in:
commit
920a4b0794
132 changed files with 7701 additions and 4766 deletions
20
.github/workflows/release.yaml
vendored
20
.github/workflows/release.yaml
vendored
|
@ -103,6 +103,7 @@ jobs:
|
||||||
path: |
|
path: |
|
||||||
llm/build/**/bin/*
|
llm/build/**/bin/*
|
||||||
llm/build/**/*.a
|
llm/build/**/*.a
|
||||||
|
dist/windows-amd64/**
|
||||||
|
|
||||||
# ROCm generation step
|
# ROCm generation step
|
||||||
generate-windows-rocm:
|
generate-windows-rocm:
|
||||||
|
@ -173,7 +174,9 @@ jobs:
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: generate-windows-rocm
|
name: generate-windows-rocm
|
||||||
path: llm/build/**/bin/*
|
path: |
|
||||||
|
llm/build/**/bin/*
|
||||||
|
dist/windows-amd64/**
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-rocm-deps
|
name: windows-rocm-deps
|
||||||
|
@ -253,7 +256,9 @@ jobs:
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: generate-windows-cuda
|
name: generate-windows-cuda
|
||||||
path: llm/build/**/bin/*
|
path: |
|
||||||
|
llm/build/**/bin/*
|
||||||
|
dist/windows-amd64/**
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-cuda-deps
|
name: windows-cuda-deps
|
||||||
|
@ -306,23 +311,18 @@ jobs:
|
||||||
- uses: actions/download-artifact@v4
|
- uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: generate-windows-cpu
|
name: generate-windows-cpu
|
||||||
path: llm/build
|
|
||||||
- uses: actions/download-artifact@v4
|
- uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: generate-windows-cuda
|
name: generate-windows-cuda
|
||||||
path: llm/build
|
|
||||||
- uses: actions/download-artifact@v4
|
- uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-cuda-deps
|
name: windows-cuda-deps
|
||||||
path: dist/deps
|
|
||||||
- uses: actions/download-artifact@v4
|
- uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-rocm-deps
|
name: windows-rocm-deps
|
||||||
path: dist/deps
|
|
||||||
- uses: actions/download-artifact@v4
|
- uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: generate-windows-rocm
|
name: generate-windows-rocm
|
||||||
path: llm/build
|
|
||||||
- run: dir llm/build
|
- run: dir llm/build
|
||||||
- run: |
|
- run: |
|
||||||
$gopath=(get-command go).source | split-path -parent
|
$gopath=(get-command go).source | split-path -parent
|
||||||
|
@ -331,13 +331,13 @@ jobs:
|
||||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||||
$env:PATH="$gopath;$env:PATH"
|
$env:PATH="$gopath;$env:PATH"
|
||||||
$env:OLLAMA_SKIP_GENERATE="1"
|
$env:OLLAMA_SKIP_GENERATE="1"
|
||||||
$env:NVIDIA_DIR=$(resolve-path ".\dist\deps")
|
|
||||||
$env:HIP_PATH=$(resolve-path ".\dist\deps")
|
|
||||||
& .\scripts\build_windows.ps1
|
& .\scripts\build_windows.ps1
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: dist-windows
|
name: dist-windows
|
||||||
path: dist/*.exe
|
path: |
|
||||||
|
dist/OllamaSetup.exe
|
||||||
|
dist/ollama-windows-*.zip
|
||||||
|
|
||||||
# Linux x86 assets built using the container based build
|
# Linux x86 assets built using the container based build
|
||||||
build-linux-amd64:
|
build-linux-amd64:
|
||||||
|
|
34
.github/workflows/test.yaml
vendored
34
.github/workflows/test.yaml
vendored
|
@ -1,5 +1,15 @@
|
||||||
name: test
|
name: test
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
# For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
|
||||||
|
# cancels running CI jobs and starts all new ones.
|
||||||
|
#
|
||||||
|
# For non-PR pushes, concurrency.group needs to be unique for every distinct
|
||||||
|
# CI run we want to have happen. Use run_id, which in practice means all
|
||||||
|
# non-PR CI runs will be allowed to run without preempting each other.
|
||||||
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.run_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
|
@ -21,7 +31,9 @@ jobs:
|
||||||
- id: changes
|
- id: changes
|
||||||
run: |
|
run: |
|
||||||
changed() {
|
changed() {
|
||||||
git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
|
git diff-tree -r --no-commit-id --name-only \
|
||||||
|
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
|
||||||
|
${{ github.event.pull_request.head.sha }} \
|
||||||
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
|
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,7 +115,9 @@ jobs:
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: cuda-${{ matrix.cuda-version }}-libraries
|
name: cuda-${{ matrix.cuda-version }}-libraries
|
||||||
path: llm/build/**/bin/*
|
path: |
|
||||||
|
llm/build/**/bin/*
|
||||||
|
dist/windows-amd64/**
|
||||||
generate-rocm:
|
generate-rocm:
|
||||||
needs: [changes]
|
needs: [changes]
|
||||||
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
|
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
|
||||||
|
@ -134,7 +148,9 @@ jobs:
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: rocm-${{ matrix.rocm-version }}-libraries
|
name: rocm-${{ matrix.rocm-version }}-libraries
|
||||||
path: llm/build/**/bin/*
|
path: |
|
||||||
|
llm/build/**/bin/*
|
||||||
|
dist/windows-amd64/**
|
||||||
|
|
||||||
# ROCm generation step
|
# ROCm generation step
|
||||||
generate-windows-rocm:
|
generate-windows-rocm:
|
||||||
|
@ -253,14 +269,9 @@ jobs:
|
||||||
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
||||||
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
||||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||||
- run: |
|
|
||||||
mkdir -p llm/build/windows/$ARCH/stub/bin
|
|
||||||
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
|
|
||||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
|
||||||
shell: bash
|
|
||||||
- uses: golangci/golangci-lint-action@v4
|
- uses: golangci/golangci-lint-action@v4
|
||||||
with:
|
with:
|
||||||
args: --timeout 8m0s
|
args: --timeout 8m0s -v
|
||||||
test:
|
test:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
@ -284,7 +295,6 @@ jobs:
|
||||||
with:
|
with:
|
||||||
go-version-file: go.mod
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- run: go get
|
|
||||||
- run: |
|
- run: |
|
||||||
case ${{ matrix.arch }} in
|
case ${{ matrix.arch }} in
|
||||||
amd64) echo ARCH=x86_64 ;;
|
amd64) echo ARCH=x86_64 ;;
|
||||||
|
@ -299,10 +309,6 @@ jobs:
|
||||||
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
||||||
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
||||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||||
- run: |
|
|
||||||
mkdir -p llm/build/windows/$ARCH/stub/bin
|
|
||||||
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
|
|
||||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
|
||||||
shell: bash
|
shell: bash
|
||||||
- run: go generate ./...
|
- run: go generate ./...
|
||||||
- run: go build
|
- run: go build
|
||||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -12,3 +12,4 @@ ggml-metal.metal
|
||||||
test_data
|
test_data
|
||||||
*.crt
|
*.crt
|
||||||
llm/build
|
llm/build
|
||||||
|
__debug_bin*
|
14
Dockerfile
14
Dockerfile
|
@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||||
|
|
||||||
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
|
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
|
||||||
ARG CMAKE_VERSION
|
ARG CMAKE_VERSION
|
||||||
|
@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||||
|
|
||||||
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
|
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
|
||||||
ARG CMAKE_VERSION
|
ARG CMAKE_VERSION
|
||||||
|
@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||||
ARG CGO_CFLAGS
|
ARG CGO_CFLAGS
|
||||||
ARG AMDGPU_TARGETS
|
ARG AMDGPU_TARGETS
|
||||||
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
|
||||||
RUN mkdir /tmp/scratch && \
|
RUN mkdir /tmp/scratch && \
|
||||||
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
|
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
|
||||||
cp ${dep} /tmp/scratch/ || exit 1 ; \
|
cp ${dep} /tmp/scratch/ || exit 1 ; \
|
||||||
|
@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
|
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
|
||||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
|
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
|
||||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
|
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
|
||||||
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
||||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
||||||
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||||
|
|
||||||
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
|
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
|
||||||
ARG CMAKE_VERSION
|
ARG CMAKE_VERSION
|
||||||
|
@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
|
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
|
||||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
|
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
|
||||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||||
|
|
||||||
|
|
||||||
# Intermediate stage used for ./scripts/build_linux.sh
|
# Intermediate stage used for ./scripts/build_linux.sh
|
||||||
|
|
67
README.md
67
README.md
|
@ -1,5 +1,5 @@
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
# Ollama
|
# Ollama
|
||||||
|
@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
|
||||||
|
|
||||||
## Quickstart
|
## Quickstart
|
||||||
|
|
||||||
To run and chat with [Llama 2](https://ollama.com/library/llama2):
|
To run and chat with [Llama 3](https://ollama.com/library/llama3):
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama run llama2
|
ollama run llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
## Model library
|
## Model library
|
||||||
|
@ -49,17 +49,14 @@ Here are some example models that can be downloaded:
|
||||||
|
|
||||||
| Model | Parameters | Size | Download |
|
| Model | Parameters | Size | Download |
|
||||||
| ------------------ | ---------- | ----- | ------------------------------ |
|
| ------------------ | ---------- | ----- | ------------------------------ |
|
||||||
| Llama 2 | 7B | 3.8GB | `ollama run llama2` |
|
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
|
||||||
|
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
|
||||||
|
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
|
||||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||||
| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` |
|
|
||||||
| Phi-2 | 2.7B | 1.7GB | `ollama run phi` |
|
|
||||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||||
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
|
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
|
||||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||||
| Llama 2 13B | 13B | 7.3GB | `ollama run llama2:13b` |
|
|
||||||
| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` |
|
|
||||||
| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` |
|
|
||||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||||
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
|
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
|
||||||
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
|
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
|
||||||
|
@ -97,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information.
|
||||||
|
|
||||||
### Customize a prompt
|
### Customize a prompt
|
||||||
|
|
||||||
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
|
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama pull llama2
|
ollama pull llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
Create a `Modelfile`:
|
Create a `Modelfile`:
|
||||||
|
|
||||||
```
|
```
|
||||||
FROM llama2
|
FROM llama3
|
||||||
|
|
||||||
# set the temperature to 1 [higher is more creative, lower is more coherent]
|
# set the temperature to 1 [higher is more creative, lower is more coherent]
|
||||||
PARAMETER temperature 1
|
PARAMETER temperature 1
|
||||||
|
@ -141,7 +138,7 @@ ollama create mymodel -f ./Modelfile
|
||||||
### Pull a model
|
### Pull a model
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama pull llama2
|
ollama pull llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
> This command can also be used to update a local model. Only the diff will be pulled.
|
> This command can also be used to update a local model. Only the diff will be pulled.
|
||||||
|
@ -149,13 +146,13 @@ ollama pull llama2
|
||||||
### Remove a model
|
### Remove a model
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama rm llama2
|
ollama rm llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
### Copy a model
|
### Copy a model
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama cp llama2 my-llama2
|
ollama cp llama3 my-model
|
||||||
```
|
```
|
||||||
|
|
||||||
### Multiline input
|
### Multiline input
|
||||||
|
@ -176,10 +173,10 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
|
||||||
The image features a yellow smiley face, which is likely the central focus of the picture.
|
The image features a yellow smiley face, which is likely the central focus of the picture.
|
||||||
```
|
```
|
||||||
|
|
||||||
### Pass in prompt as arguments
|
### Pass the prompt as an argument
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ollama run llama2 "Summarize this file: $(cat README.md)"
|
$ ollama run llama3 "Summarize this file: $(cat README.md)"
|
||||||
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -226,7 +223,7 @@ Next, start the server:
|
||||||
Finally, in a separate shell, run a model:
|
Finally, in a separate shell, run a model:
|
||||||
|
|
||||||
```
|
```
|
||||||
./ollama run llama2
|
./ollama run llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
## REST API
|
## REST API
|
||||||
|
@ -237,7 +234,7 @@ Ollama has a REST API for running and managing models.
|
||||||
|
|
||||||
```
|
```
|
||||||
curl http://localhost:11434/api/generate -d '{
|
curl http://localhost:11434/api/generate -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"prompt":"Why is the sky blue?"
|
"prompt":"Why is the sky blue?"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
@ -246,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||||
|
|
||||||
```
|
```
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "mistral",
|
"model": "llama3",
|
||||||
"messages": [
|
"messages": [
|
||||||
{ "role": "user", "content": "why is the sky blue?" }
|
{ "role": "user", "content": "why is the sky blue?" }
|
||||||
]
|
]
|
||||||
|
@ -259,16 +256,18 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||||
|
|
||||||
### Web & Desktop
|
### Web & Desktop
|
||||||
|
|
||||||
|
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||||
|
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||||
|
- [Hollama](https://github.com/fmaclen/hollama)
|
||||||
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
|
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
|
||||||
- [LibreChat](https://github.com/danny-avila/LibreChat)
|
- [LibreChat](https://github.com/danny-avila/LibreChat)
|
||||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
|
||||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||||
- [Saddle](https://github.com/jikkuatwork/saddle)
|
- [Saddle](https://github.com/jikkuatwork/saddle)
|
||||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||||
|
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
|
||||||
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
|
||||||
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
|
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
|
||||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
|
||||||
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
|
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
|
||||||
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
|
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
|
||||||
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
|
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
|
||||||
|
@ -286,13 +285,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||||
- [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
|
- [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
|
||||||
- [OpenAOE](https://github.com/InternLM/OpenAOE)
|
- [OpenAOE](https://github.com/InternLM/OpenAOE)
|
||||||
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
|
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
|
||||||
- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x)
|
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
|
||||||
- [AnythingLLM (Docker + macOS/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
|
- [AnythingLLM (Docker + macOS/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
|
||||||
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
||||||
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
||||||
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
|
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository)
|
||||||
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
|
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
|
||||||
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
|
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
|
||||||
|
- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
|
||||||
|
- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold)
|
||||||
|
- [chat](https://github.com/swuecho/chat) (chat web app for teams)
|
||||||
|
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
|
||||||
|
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
|
||||||
|
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
|
||||||
|
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
|
||||||
|
|
||||||
### Terminal
|
### Terminal
|
||||||
|
|
||||||
|
@ -308,11 +314,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||||
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
|
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
|
||||||
- [cmdh](https://github.com/pgibler/cmdh)
|
- [cmdh](https://github.com/pgibler/cmdh)
|
||||||
- [ooo](https://github.com/npahlfer/ooo)
|
- [ooo](https://github.com/npahlfer/ooo)
|
||||||
|
- [shell-pilot](https://github.com/reid41/shell-pilot)
|
||||||
- [tenere](https://github.com/pythops/tenere)
|
- [tenere](https://github.com/pythops/tenere)
|
||||||
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
|
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
|
||||||
- [typechat-cli](https://github.com/anaisbetts/typechat-cli)
|
- [typechat-cli](https://github.com/anaisbetts/typechat-cli)
|
||||||
- [ShellOracle](https://github.com/djcopley/ShellOracle)
|
- [ShellOracle](https://github.com/djcopley/ShellOracle)
|
||||||
- [tlm](https://github.com/yusufcanb/tlm)
|
- [tlm](https://github.com/yusufcanb/tlm)
|
||||||
|
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
|
||||||
|
|
||||||
### Database
|
### Database
|
||||||
|
|
||||||
|
@ -344,9 +352,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||||
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
|
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
|
||||||
- [Elixir LangChain](https://github.com/brainlid/langchain)
|
- [Elixir LangChain](https://github.com/brainlid/langchain)
|
||||||
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
|
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
|
||||||
|
- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
|
||||||
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
|
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
|
||||||
- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
|
- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
|
||||||
- [Testcontainers](https://testcontainers.com/modules/ollama/)
|
- [Testcontainers](https://testcontainers.com/modules/ollama/)
|
||||||
|
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
|
||||||
|
|
||||||
### Mobile
|
### Mobile
|
||||||
|
|
||||||
|
@ -366,17 +376,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||||
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
|
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
|
||||||
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
|
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
|
||||||
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
|
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
|
||||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
|
||||||
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
|
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
|
||||||
- [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
|
- [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
|
||||||
- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
|
- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
|
||||||
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
|
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
|
||||||
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
||||||
|
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
||||||
|
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
|
||||||
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
|
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
|
||||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
|
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
|
||||||
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
|
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
|
||||||
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
|
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
|
||||||
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
|
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
|
||||||
|
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
|
||||||
|
|
||||||
### Supported backends
|
### Supported backends
|
||||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,16 @@
|
||||||
// Package api implements the client-side API for code wishing to interact
|
// Package api implements the client-side API for code wishing to interact
|
||||||
// with the ollama service. The methods of the [Client] type correspond to
|
// with the ollama service. The methods of the [Client] type correspond to
|
||||||
// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
|
// the ollama REST API as described in [the API documentation].
|
||||||
//
|
|
||||||
// The ollama command-line client itself uses this package to interact with
|
// The ollama command-line client itself uses this package to interact with
|
||||||
// the backend service.
|
// the backend service.
|
||||||
|
//
|
||||||
|
// # Examples
|
||||||
|
//
|
||||||
|
// Several examples of using this package are available [in the GitHub
|
||||||
|
// repository].
|
||||||
|
//
|
||||||
|
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||||
|
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -18,6 +25,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
@ -57,12 +65,36 @@ func checkError(resp *http.Response, body []byte) error {
|
||||||
// If the variable is not specified, a default ollama host and port will be
|
// If the variable is not specified, a default ollama host and port will be
|
||||||
// used.
|
// used.
|
||||||
func ClientFromEnvironment() (*Client, error) {
|
func ClientFromEnvironment() (*Client, error) {
|
||||||
|
ollamaHost, err := GetOllamaHost()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
base: &url.URL{
|
||||||
|
Scheme: ollamaHost.Scheme,
|
||||||
|
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||||
|
},
|
||||||
|
http: http.DefaultClient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaHost struct {
|
||||||
|
Scheme string
|
||||||
|
Host string
|
||||||
|
Port string
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetOllamaHost() (OllamaHost, error) {
|
||||||
defaultPort := "11434"
|
defaultPort := "11434"
|
||||||
|
|
||||||
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
|
hostVar := os.Getenv("OLLAMA_HOST")
|
||||||
|
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
|
||||||
|
|
||||||
|
scheme, hostport, ok := strings.Cut(hostVar, "://")
|
||||||
switch {
|
switch {
|
||||||
case !ok:
|
case !ok:
|
||||||
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
|
scheme, hostport = "http", hostVar
|
||||||
case scheme == "http":
|
case scheme == "http":
|
||||||
defaultPort = "80"
|
defaultPort = "80"
|
||||||
case scheme == "https":
|
case scheme == "https":
|
||||||
|
@ -82,15 +114,24 @@ func ClientFromEnvironment() (*Client, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Client{
|
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
|
||||||
base: &url.URL{
|
return OllamaHost{}, ErrInvalidHostPort
|
||||||
|
}
|
||||||
|
|
||||||
|
return OllamaHost{
|
||||||
Scheme: scheme,
|
Scheme: scheme,
|
||||||
Host: net.JoinHostPort(host, port),
|
Host: host,
|
||||||
},
|
Port: port,
|
||||||
http: http.DefaultClient,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewClient(base *url.URL, http *http.Client) *Client {
|
||||||
|
return &Client{
|
||||||
|
base: base,
|
||||||
|
http: http,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
||||||
var reqBody io.Reader
|
var reqBody io.Reader
|
||||||
var data []byte
|
var data []byte
|
||||||
|
@ -265,8 +306,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PushProgressFunc is a function that [Client.Push] invokes when progress is
|
||||||
|
// made.
|
||||||
|
// It's similar to other progress function types like [PullProgressFunc].
|
||||||
type PushProgressFunc func(ProgressResponse) error
|
type PushProgressFunc func(ProgressResponse) error
|
||||||
|
|
||||||
|
// Push uploads a model to the model library; requires registering for ollama.ai
|
||||||
|
// and adding a public key first. fn is called each time progress is made on
|
||||||
|
// the request and can be used to display a progress bar, etc.
|
||||||
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
|
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
|
||||||
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
|
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
|
||||||
var resp ProgressResponse
|
var resp ProgressResponse
|
||||||
|
@ -278,8 +325,15 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateProgressFunc is a function that [Client.Create] invokes when progress
|
||||||
|
// is made.
|
||||||
|
// It's similar to other progress function types like [PullProgressFunc].
|
||||||
type CreateProgressFunc func(ProgressResponse) error
|
type CreateProgressFunc func(ProgressResponse) error
|
||||||
|
|
||||||
|
// Create creates a model from a [Modelfile]. fn is a progress function that
|
||||||
|
// behaves similarly to other methods (see [Client.Pull]).
|
||||||
|
//
|
||||||
|
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
|
||||||
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
|
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
|
||||||
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
|
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
|
||||||
var resp ProgressResponse
|
var resp ProgressResponse
|
||||||
|
@ -291,6 +345,7 @@ func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgre
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// List lists models that are available locally.
|
||||||
func (c *Client) List(ctx context.Context) (*ListResponse, error) {
|
func (c *Client) List(ctx context.Context) (*ListResponse, error) {
|
||||||
var lr ListResponse
|
var lr ListResponse
|
||||||
if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
|
if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
|
||||||
|
@ -299,6 +354,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
|
||||||
return &lr, nil
|
return &lr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy copies a model - creating a model with another name from an existing
|
||||||
|
// model.
|
||||||
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
|
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
|
||||||
if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
|
if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -306,6 +363,7 @@ func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Delete deletes a model and its data.
|
||||||
func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
|
func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
|
||||||
if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
|
if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -313,6 +371,7 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Show obtains model information, including details, modelfile, license etc.
|
||||||
func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
|
func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
|
||||||
var resp ShowResponse
|
var resp ShowResponse
|
||||||
if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
|
if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
|
||||||
|
@ -321,12 +380,16 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Heartbeat checks if the server has started and is responsive; if yes, it
|
||||||
|
// returns nil, otherwise an error.
|
||||||
func (c *Client) Heartbeat(ctx context.Context) error {
|
func (c *Client) Heartbeat(ctx context.Context) error {
|
||||||
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
|
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Embeddings generates embeddings from a model.
|
||||||
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||||
var resp EmbeddingResponse
|
var resp EmbeddingResponse
|
||||||
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
|
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
|
||||||
|
@ -335,10 +398,13 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateBlob creates a blob from a file on the server. digest is the
|
||||||
|
// expected SHA256 digest of the file, and r represents the file.
|
||||||
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
|
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
|
||||||
return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
|
return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Version returns the Ollama server version as a string.
|
||||||
func (c *Client) Version(ctx context.Context) (string, error) {
|
func (c *Client) Version(ctx context.Context) (string, error) {
|
||||||
var version struct {
|
var version struct {
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
func TestClientFromEnvironment(t *testing.T) {
|
func TestClientFromEnvironment(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
|
@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hostTestCases := map[string]*testCase{
|
||||||
|
"empty": {value: "", expect: "127.0.0.1:11434"},
|
||||||
|
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
|
||||||
|
"only port": {value: ":1234", expect: ":1234"},
|
||||||
|
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
|
||||||
|
"hostname": {value: "example.com", expect: "example.com:11434"},
|
||||||
|
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
|
||||||
|
"zero port": {value: ":0", expect: ":0"},
|
||||||
|
"too large port": {value: ":66000", err: ErrInvalidHostPort},
|
||||||
|
"too small port": {value: ":-1", err: ErrInvalidHostPort},
|
||||||
|
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
|
||||||
|
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
|
||||||
|
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
|
||||||
|
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
|
||||||
|
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
|
||||||
|
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
|
||||||
|
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
|
||||||
|
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range hostTestCases {
|
||||||
|
t.Run(k, func(t *testing.T) {
|
||||||
|
t.Setenv("OLLAMA_HOST", v.value)
|
||||||
|
|
||||||
|
oh, err := GetOllamaHost()
|
||||||
|
if err != v.err {
|
||||||
|
t.Fatalf("expected %s, got %s", v.err, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
host := net.JoinHostPort(oh.Host, oh.Port)
|
||||||
|
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
77
api/types.go
77
api/types.go
|
@ -2,6 +2,7 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
|
@ -11,6 +12,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StatusError is an error with an HTTP status code.
|
||||||
type StatusError struct {
|
type StatusError struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
Status string
|
Status string
|
||||||
|
@ -31,6 +33,7 @@ func (e StatusError) Error() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageData represents the raw binary data of an image file.
|
||||||
type ImageData []byte
|
type ImageData []byte
|
||||||
|
|
||||||
// GenerateRequest describes a request sent by [Client.Generate]. While you
|
// GenerateRequest describes a request sent by [Client.Generate]. While you
|
||||||
|
@ -76,22 +79,39 @@ type GenerateRequest struct {
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ChatRequest describes a request sent by [Client.Chat].
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
|
// Model is the model name, as in [GenerateRequest].
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
// Messages is the messages of the chat - can be used to keep a chat memory.
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
|
|
||||||
|
// Stream enables streaming of the returned response; true by default.
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
|
|
||||||
|
// Format is the format to return the response in (e.g. "json").
|
||||||
Format string `json:"format"`
|
Format string `json:"format"`
|
||||||
|
|
||||||
|
// KeepAlive controls how long the model will stay loaded into memory
|
||||||
|
// following the request.
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Message is a single message in a chat sequence. The message contains the
|
||||||
|
// role ("system", "user", or "assistant"), the content and an optional list
|
||||||
|
// of images.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"` // one of ["system", "user", "assistant"]
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||||
|
// similar to [GenerateResponse].
|
||||||
type ChatResponse struct {
|
type ChatResponse struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
@ -111,7 +131,8 @@ type Metrics struct {
|
||||||
EvalDuration time.Duration `json:"eval_duration,omitempty"`
|
EvalDuration time.Duration `json:"eval_duration,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options specified in GenerateRequest, if you add a new option here add it to the API docs also
|
// Options specified in [GenerateRequest], if you add a new option here add it
|
||||||
|
// to the API docs also.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Runner
|
Runner
|
||||||
|
|
||||||
|
@ -157,18 +178,28 @@ type Runner struct {
|
||||||
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
|
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbeddingRequest is the request passed to [Client.Embeddings].
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
|
// Model is the model name.
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
// Prompt is the textual prompt to embed.
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
|
||||||
|
// KeepAlive controls how long the model will stay loaded in memory following
|
||||||
|
// this request.
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbeddingResponse is the response from [Client.Embeddings].
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateRequest is the request passed to [Client.Create].
|
||||||
type CreateRequest struct {
|
type CreateRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
|
@ -180,6 +211,7 @@ type CreateRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteRequest is the request passed to [Client.Delete].
|
||||||
type DeleteRequest struct {
|
type DeleteRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
@ -187,6 +219,7 @@ type DeleteRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ShowRequest is the request passed to [Client.Show].
|
||||||
type ShowRequest struct {
|
type ShowRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
System string `json:"system"`
|
System string `json:"system"`
|
||||||
|
@ -198,6 +231,7 @@ type ShowRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ShowResponse is the response returned from [Client.Show].
|
||||||
type ShowResponse struct {
|
type ShowResponse struct {
|
||||||
License string `json:"license,omitempty"`
|
License string `json:"license,omitempty"`
|
||||||
Modelfile string `json:"modelfile,omitempty"`
|
Modelfile string `json:"modelfile,omitempty"`
|
||||||
|
@ -208,11 +242,13 @@ type ShowResponse struct {
|
||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopyRequest is the request passed to [Client.Copy].
|
||||||
type CopyRequest struct {
|
type CopyRequest struct {
|
||||||
Source string `json:"source"`
|
Source string `json:"source"`
|
||||||
Destination string `json:"destination"`
|
Destination string `json:"destination"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PullRequest is the request passed to [Client.Pull].
|
||||||
type PullRequest struct {
|
type PullRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Insecure bool `json:"insecure,omitempty"`
|
Insecure bool `json:"insecure,omitempty"`
|
||||||
|
@ -224,6 +260,8 @@ type PullRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProgressResponse is the response passed to progress functions like
|
||||||
|
// [PullProgressFunc] and [PushProgressFunc].
|
||||||
type ProgressResponse struct {
|
type ProgressResponse struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Digest string `json:"digest,omitempty"`
|
Digest string `json:"digest,omitempty"`
|
||||||
|
@ -231,6 +269,7 @@ type ProgressResponse struct {
|
||||||
Completed int64 `json:"completed,omitempty"`
|
Completed int64 `json:"completed,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PushRequest is the request passed to [Client.Push].
|
||||||
type PushRequest struct {
|
type PushRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Insecure bool `json:"insecure,omitempty"`
|
Insecure bool `json:"insecure,omitempty"`
|
||||||
|
@ -242,10 +281,12 @@ type PushRequest struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListResponse is the response from [Client.List].
|
||||||
type ListResponse struct {
|
type ListResponse struct {
|
||||||
Models []ModelResponse `json:"models"`
|
Models []ModelResponse `json:"models"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelResponse is a single model description in [ListResponse].
|
||||||
type ModelResponse struct {
|
type ModelResponse struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
@ -259,17 +300,28 @@ type TokenResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateResponse is the response passed into [GenerateResponseFunc].
|
||||||
type GenerateResponse struct {
|
type GenerateResponse struct {
|
||||||
|
// Model is the model name that generated the response.
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
// CreatedAt is the timestamp of the response.
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
|
||||||
|
// Response is the textual response itself.
|
||||||
Response string `json:"response"`
|
Response string `json:"response"`
|
||||||
|
|
||||||
|
// Done specifies if the response is complete.
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
|
|
||||||
|
// Context is an encoding of the conversation used in this response; this
|
||||||
|
// can be sent in the next request to keep a conversational memory.
|
||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
|
||||||
Metrics
|
Metrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelDetails provides details about a model.
|
||||||
type ModelDetails struct {
|
type ModelDetails struct {
|
||||||
ParentModel string `json:"parent_model"`
|
ParentModel string `json:"parent_model"`
|
||||||
Format string `json:"format"`
|
Format string `json:"format"`
|
||||||
|
@ -307,7 +359,9 @@ func (m *Metrics) Summary() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrInvalidOpts = fmt.Errorf("invalid options")
|
// ErrInvalidOpts is returned when invalid options are passed to the client.
|
||||||
|
var ErrInvalidOpts = errors.New("invalid options")
|
||||||
|
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
|
@ -392,11 +446,15 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultOptions is the default set of options for [GenerateRequest]; these
|
||||||
|
// values are used unless the user specifies other values explicitly.
|
||||||
func DefaultOptions() Options {
|
func DefaultOptions() Options {
|
||||||
return Options{
|
return Options{
|
||||||
// options set on request to runner
|
// options set on request to runner
|
||||||
NumPredict: -1,
|
NumPredict: -1,
|
||||||
NumKeep: 0,
|
|
||||||
|
// set a minimal num_keep to avoid issues on context shifts
|
||||||
|
NumKeep: 4,
|
||||||
Temperature: 0.8,
|
Temperature: 0.8,
|
||||||
TopK: 40,
|
TopK: 40,
|
||||||
TopP: 0.9,
|
TopP: 0.9,
|
||||||
|
@ -432,6 +490,13 @@ type Duration struct {
|
||||||
time.Duration
|
time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||||
|
if d.Duration < 0 {
|
||||||
|
return []byte("-1"), nil
|
||||||
|
}
|
||||||
|
return []byte("\"" + d.Duration.String() + "\""), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||||
var v any
|
var v any
|
||||||
if err := json.Unmarshal(b, &v); err != nil {
|
if err := json.Unmarshal(b, &v); err != nil {
|
||||||
|
@ -445,7 +510,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||||
if t < 0 {
|
if t < 0 {
|
||||||
d.Duration = time.Duration(math.MaxInt64)
|
d.Duration = time.Duration(math.MaxInt64)
|
||||||
} else {
|
} else {
|
||||||
d.Duration = time.Duration(t * float64(time.Second))
|
d.Duration = time.Duration(int(t) * int(time.Second))
|
||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
d.Duration, err = time.ParseDuration(t)
|
d.Duration, err = time.ParseDuration(t)
|
||||||
|
@ -455,6 +520,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||||
if d.Duration < 0 {
|
if d.Duration < 0 {
|
||||||
d.Duration = time.Duration(math.MaxInt64)
|
d.Duration = time.Duration(math.MaxInt64)
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||||
req: `{ "keep_alive": 42 }`,
|
req: `{ "keep_alive": 42 }`,
|
||||||
exp: &Duration{42 * time.Second},
|
exp: &Duration{42 * time.Second},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Positive Float",
|
||||||
|
req: `{ "keep_alive": 42.5 }`,
|
||||||
|
exp: &Duration{42 * time.Second},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Positive Integer String",
|
name: "Positive Integer String",
|
||||||
req: `{ "keep_alive": "42m" }`,
|
req: `{ "keep_alive": "42m" }`,
|
||||||
|
@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||||
req: `{ "keep_alive": -1 }`,
|
req: `{ "keep_alive": -1 }`,
|
||||||
exp: &Duration{math.MaxInt64},
|
exp: &Duration{math.MaxInt64},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Negative Float",
|
||||||
|
req: `{ "keep_alive": -3.14 }`,
|
||||||
|
exp: &Duration{math.MaxInt64},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Negative Integer String",
|
name: "Negative Integer String",
|
||||||
req: `{ "keep_alive": "-1m" }`,
|
req: `{ "keep_alive": "-1m" }`,
|
||||||
|
@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDurationMarshalUnmarshal(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input time.Duration
|
||||||
|
expected time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"negative duration",
|
||||||
|
time.Duration(-1),
|
||||||
|
time.Duration(math.MaxInt64),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"positive duration",
|
||||||
|
time.Duration(42 * time.Second),
|
||||||
|
time.Duration(42 * time.Second),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"another positive duration",
|
||||||
|
time.Duration(42 * time.Minute),
|
||||||
|
time.Duration(42 * time.Minute),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"zero duration",
|
||||||
|
time.Duration(0),
|
||||||
|
time.Duration(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"max duration",
|
||||||
|
time.Duration(math.MaxInt64),
|
||||||
|
time.Duration(math.MaxInt64),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
b, err := json.Marshal(Duration{test.input})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var d Duration
|
||||||
|
err = json.Unmarshal(b, &d)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,12 +5,14 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/server/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
func InitLogging() {
|
func InitLogging() {
|
||||||
level := slog.LevelInfo
|
level := slog.LevelInfo
|
||||||
|
|
||||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
if envconfig.Debug {
|
||||||
level = slog.LevelDebug
|
level = slog.LevelDebug
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
|
||||||
return command
|
return command
|
||||||
}
|
}
|
||||||
|
|
||||||
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
||||||
done := make(chan int)
|
|
||||||
|
|
||||||
logDir := filepath.Dir(ServerLogFile)
|
|
||||||
_, err := os.Stat(logDir)
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
|
||||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
|
||||||
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := getCmd(ctx, getCLIFullPath(command))
|
cmd := getCmd(ctx, getCLIFullPath(command))
|
||||||
// send stdout and stderr to a file
|
|
||||||
stdout, err := cmd.StdoutPipe()
|
stdout, err := cmd.StdoutPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
|
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
|
||||||
}
|
}
|
||||||
stderr, err := cmd.StderrPipe()
|
stderr, err := cmd.StderrPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
|
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
|
||||||
}
|
|
||||||
stdin, err := cmd.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO - rotation
|
// TODO - rotation
|
||||||
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return done, fmt.Errorf("failed to create server log %w", err)
|
return nil, fmt.Errorf("failed to create server log: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logDir := filepath.Dir(ServerLogFile)
|
||||||
|
_, err = os.Stat(logDir)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||||
|
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer logFile.Close()
|
defer logFile.Close()
|
||||||
io.Copy(logFile, stdout) //nolint:errcheck
|
io.Copy(logFile, stdout) //nolint:errcheck
|
||||||
|
@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||||
|
|
||||||
// run the command and wait for it to finish
|
// run the command and wait for it to finish
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
return done, fmt.Errorf("failed to start server %w", err)
|
return nil, fmt.Errorf("failed to start server %w", err)
|
||||||
}
|
}
|
||||||
if cmd.Process != nil {
|
if cmd.Process != nil {
|
||||||
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
|
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
|
||||||
}
|
}
|
||||||
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
|
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
|
||||||
|
|
||||||
|
return cmd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||||
|
done := make(chan int)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Keep the server running unless we're shutting down the app
|
// Keep the server running unless we're shutting down the app
|
||||||
crashCount := 0
|
crashCount := 0
|
||||||
for {
|
for {
|
||||||
|
slog.Info("starting server...")
|
||||||
|
cmd, err := start(ctx, command)
|
||||||
|
if err != nil {
|
||||||
|
crashCount++
|
||||||
|
slog.Error(fmt.Sprintf("failed to start server %s", err))
|
||||||
|
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Wait() //nolint:errcheck
|
cmd.Wait() //nolint:errcheck
|
||||||
stdin.Close()
|
|
||||||
var code int
|
var code int
|
||||||
if cmd.ProcessState != nil {
|
if cmd.ProcessState != nil {
|
||||||
code = cmd.ProcessState.ExitCode()
|
code = cmd.ProcessState.ExitCode()
|
||||||
|
@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||||
default:
|
default:
|
||||||
crashCount++
|
crashCount++
|
||||||
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
|
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
|
||||||
if err := cmd.Start(); err != nil {
|
break
|
||||||
slog.Error(fmt.Sprintf("failed to restart server %s", err))
|
|
||||||
// Keep trying, but back off if we keep failing
|
|
||||||
time.Sleep(time.Duration(crashCount) * time.Second)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return done, nil
|
return done, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
|
||||||
"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
|
"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
|
||||||
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
|
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
|
||||||
}
|
}
|
||||||
// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts)
|
// make the upgrade as quiet as possible (no GUI, no prompts)
|
||||||
// TODO - temporarily disable since we're pinning in debug mode for the preview
|
|
||||||
// if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" {
|
|
||||||
installArgs = append(installArgs,
|
installArgs = append(installArgs,
|
||||||
"/SP", // Skip the "This will install... Do you wish to continue" prompt
|
"/SP", // Skip the "This will install... Do you wish to continue" prompt
|
||||||
"/SUPPRESSMSGBOXES",
|
"/SUPPRESSMSGBOXES",
|
||||||
"/SILENT",
|
"/SILENT",
|
||||||
"/VERYSILENT",
|
"/VERYSILENT",
|
||||||
)
|
)
|
||||||
// }
|
|
||||||
|
|
||||||
// Safeguard in case we have requests in flight that need to drain...
|
// Safeguard in case we have requests in flight that need to drain...
|
||||||
slog.Info("Waiting for server to shutdown")
|
slog.Info("Waiting for server to shutdown")
|
||||||
|
|
|
@ -88,15 +88,12 @@ DialogFontSize=12
|
||||||
[Files]
|
[Files]
|
||||||
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
|
||||||
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||||
Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
|
||||||
|
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
|
||||||
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
|
||||||
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
|
||||||
; Assumes v5.7, may need adjustments for v6
|
#if DirExists("..\dist\windows-amd64\rocm")
|
||||||
#if GetEnv("HIP_PATH") != ""
|
Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
|
||||||
Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
|
|
||||||
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
|
|
||||||
; amdhip64.dll dependency comes from the driver and must be installed already
|
|
||||||
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
|
||||||
|
|
||||||
|
|
||||||
;FinishedHeadingLabel=Run your first model
|
;FinishedHeadingLabel=Run your first model
|
||||||
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama2
|
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3
|
||||||
;ClickFinish=%n
|
;ClickFinish=%n
|
||||||
|
|
||||||
[Registry]
|
[Registry]
|
||||||
|
|
36
auth/auth.go
36
auth/auth.go
|
@ -10,12 +10,44 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
|
func keyPath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPublicKey() (string, error) {
|
||||||
|
keyPath, err := keyPath()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||||
|
|
||||||
|
return strings.TrimSpace(string(publicKey)), nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewNonce(r io.Reader, length int) (string, error) {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
nonce := make([]byte, length)
|
nonce := make([]byte, length)
|
||||||
if _, err := io.ReadFull(r, nonce); err != nil {
|
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||||
|
@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
home, err := os.UserHomeDir()
|
keyPath, err := keyPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
|
296
cmd/cmd.go
296
cmd/cmd.go
|
@ -17,6 +17,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -31,10 +32,12 @@ import (
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/parser"
|
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/server"
|
"github.com/ollama/ollama/server"
|
||||||
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,14 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
p := progress.NewProgress(os.Stderr)
|
p := progress.NewProgress(os.Stderr)
|
||||||
defer p.Stop()
|
defer p.Stop()
|
||||||
|
|
||||||
bars := make(map[string]*progress.Bar)
|
f, err := os.Open(filename)
|
||||||
|
|
||||||
modelfile, err := os.ReadFile(filename)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
commands, err := parser.Parse(bytes.NewReader(modelfile))
|
modelfile, err := model.ParseFile(f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -74,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
spinner := progress.NewSpinner(status)
|
spinner := progress.NewSpinner(status)
|
||||||
p.Add(status, spinner)
|
p.Add(status, spinner)
|
||||||
|
|
||||||
for _, c := range commands {
|
for i := range modelfile.Commands {
|
||||||
switch c.Name {
|
switch modelfile.Commands[i].Name {
|
||||||
case "model", "adapter":
|
case "model", "adapter":
|
||||||
path := c.Args
|
path := modelfile.Commands[i].Args
|
||||||
if path == "~" {
|
if path == "~" {
|
||||||
path = home
|
path = home
|
||||||
} else if strings.HasPrefix(path, "~/") {
|
} else if strings.HasPrefix(path, "~/") {
|
||||||
|
@ -89,101 +91,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
fi, err := os.Stat(path)
|
fi, err := os.Stat(path)
|
||||||
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
|
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO make this work w/ adapters
|
|
||||||
if fi.IsDir() {
|
if fi.IsDir() {
|
||||||
tf, err := os.CreateTemp("", "ollama-tf")
|
// this is likely a safetensors or pytorch directory
|
||||||
|
// TODO make this work w/ adapters
|
||||||
|
tempfile, err := tempZipFiles(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tf.Name())
|
defer os.RemoveAll(tempfile)
|
||||||
|
|
||||||
zf := zip.NewWriter(tf)
|
path = tempfile
|
||||||
|
|
||||||
files := []string{}
|
|
||||||
|
|
||||||
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
} else if len(tfiles) == 0 {
|
|
||||||
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
files = append(files, tfiles...)
|
|
||||||
|
|
||||||
if len(files) == 0 {
|
|
||||||
return fmt.Errorf("no models were found in '%s'", path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// add the safetensor/torch config file + tokenizer
|
|
||||||
files = append(files, filepath.Join(path, "config.json"))
|
|
||||||
files = append(files, filepath.Join(path, "params.json"))
|
|
||||||
files = append(files, filepath.Join(path, "added_tokens.json"))
|
|
||||||
files = append(files, filepath.Join(path, "tokenizer.model"))
|
|
||||||
|
|
||||||
for _, fn := range files {
|
|
||||||
f, err := os.Open(fn)
|
|
||||||
|
|
||||||
// just skip whatever files aren't there
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
if strings.HasSuffix(fn, "tokenizer.model") {
|
|
||||||
// try the parent dir before giving up
|
|
||||||
parentDir := filepath.Dir(path)
|
|
||||||
newFn := filepath.Join(parentDir, "tokenizer.model")
|
|
||||||
f, err = os.Open(newFn)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
continue
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fi, err := f.Stat()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h, err := zip.FileInfoHeader(fi)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.Name = filepath.Base(fn)
|
|
||||||
h.Method = zip.Store
|
|
||||||
|
|
||||||
w, err := zf.CreateHeader(h)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = io.Copy(w, f)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := zf.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tf.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
path = tf.Name()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
digest, err := createBlob(cmd, client, path)
|
digest, err := createBlob(cmd, client, path)
|
||||||
|
@ -191,10 +114,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
|
modelfile.Commands[i].Args = "@" + digest
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bars := make(map[string]*progress.Bar)
|
||||||
fn := func(resp api.ProgressResponse) error {
|
fn := func(resp api.ProgressResponse) error {
|
||||||
if resp.Digest != "" {
|
if resp.Digest != "" {
|
||||||
spinner.Stop()
|
spinner.Stop()
|
||||||
|
@ -220,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
quantization, _ := cmd.Flags().GetString("quantization")
|
quantization, _ := cmd.Flags().GetString("quantization")
|
||||||
|
|
||||||
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
|
request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
|
||||||
if err := client.Create(cmd.Context(), &request, fn); err != nil {
|
if err := client.Create(cmd.Context(), &request, fn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -228,6 +152,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tempZipFiles bundles the model files found directly in the directory path
// into a new temporary zip archive and returns the archive's file name. The
// caller owns the returned file and is responsible for removing it.
//
// Weight files are chosen from the first of these patterns that yields at
// least one file with the expected sniffed content type:
//
//   - model*.safetensors  (application/octet-stream)
//   - pytorch_model*.bin  (application/zip)
//   - consolidated*.pth   (application/octet-stream)
//
// If none match, an error is returned. Every *.json file in the directory is
// then added (each must sniff as text/plain), followed by tokenizer.model
// when one is found. Entries are stored under their base file name.
//
// On any failure the partially written temporary file is closed and removed,
// and an empty name is returned together with the error.
func tempZipFiles(path string) (_ string, err error) {
	tempfile, err := os.CreateTemp("", "ollama-tf")
	if err != nil {
		return "", err
	}
	// Every "return" below assigns the named result err, so this single
	// deferred cleanup covers all failure paths. Errors from the cleanup
	// itself are deliberately ignored: the original failure is what matters.
	defer func() {
		if err != nil {
			_ = tempfile.Close()
			_ = os.Remove(tempfile.Name())
		}
	}()

	zipfile := zip.NewWriter(tempfile)

	// detectContentType sniffs the MIME type of a file from its first 512
	// bytes and returns it without any parameters (e.g. "; charset=utf-8").
	detectContentType := func(path string) (string, error) {
		f, err := os.Open(path)
		if err != nil {
			return "", err
		}
		defer f.Close()

		var b bytes.Buffer
		b.Grow(512)

		// Files shorter than 512 bytes end the copy with io.EOF; that is fine.
		if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
			return "", err
		}

		contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
		return contentType, nil
	}

	// glob returns the files matching pattern, failing if any of them does
	// not sniff as contentType.
	glob := func(pattern, contentType string) ([]string, error) {
		matches, err := filepath.Glob(pattern)
		if err != nil {
			return nil, err
		}

		for _, match := range matches {
			ct, err := detectContentType(match)
			if err != nil {
				return nil, err
			}
			if ct != contentType {
				return nil, fmt.Errorf("invalid content type: expected %s, got %s for %s", contentType, ct, match)
			}
		}

		return matches, nil
	}

	// The glob errors in this chain are intentionally dropped: a file with
	// the wrong content type (for example an unresolved git lfs reference)
	// just means this format is skipped and the next one is tried.
	var files []string
	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
		files = append(files, st...)
	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
		files = append(files, pt...)
	} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
		// covers consolidated.x.pth, consolidated.pth
		files = append(files, pt...)
	} else {
		return "", errors.New("no safetensors or torch files found")
	}

	// add configuration files, json files are detected as text/plain
	js, err := glob(filepath.Join(path, "*.json"), "text/plain")
	if err != nil {
		return "", err
	}
	files = append(files, js...)

	if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
		// add tokenizer.model if it exists; tokenizer.json is automatically
		// picked up by the *.json glob above
		files = append(files, tks...)
	} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
		// sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
		// NOTE(review): filepath.Glob has no recursive "**", so this only
		// looks one directory level down; the text/plain expectation is kept
		// from the original code — confirm both are intended.
		files = append(files, tks...)
	}

	// addFile copies one file into the archive. Doing the work in a closure
	// lets the deferred Close run per file instead of piling up open
	// descriptors until tempZipFiles returns.
	addFile := func(name string) error {
		f, err := os.Open(name)
		if err != nil {
			return err
		}
		defer f.Close()

		fi, err := f.Stat()
		if err != nil {
			return err
		}

		zfi, err := zip.FileInfoHeader(fi)
		if err != nil {
			return err
		}

		zf, err := zipfile.CreateHeader(zfi)
		if err != nil {
			return err
		}

		_, err = io.Copy(zf, f)
		return err
	}

	for _, file := range files {
		if err := addFile(file); err != nil {
			return "", err
		}
	}

	// Closing the zip writer emits the central directory; if that fails the
	// archive is unusable, so the error must be reported, not dropped.
	if err := zipfile.Close(); err != nil {
		return "", err
	}

	if err := tempfile.Close(); err != nil {
		return "", err
	}

	return tempfile.Name(), nil
}
|
||||||
|
|
||||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||||
bin, err := os.Open(path)
|
bin, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -322,6 +354,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func errFromUnknownKey(unknownKeyErr error) error {
|
||||||
|
// find SSH public key in the error message
|
||||||
|
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||||
|
re := regexp.MustCompile(sshKeyPattern)
|
||||||
|
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||||
|
|
||||||
|
if len(matches) > 0 {
|
||||||
|
serverPubKey := matches[0]
|
||||||
|
|
||||||
|
localPubKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||||
|
// try the ollama service public key
|
||||||
|
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||||
|
if err != nil {
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||||
|
if serverPubKey != localPubKey {
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg strings.Builder
|
||||||
|
msg.WriteString(unknownKeyErr.Error())
|
||||||
|
msg.WriteString("\n\nYour ollama key is:\n")
|
||||||
|
msg.WriteString(localPubKey)
|
||||||
|
msg.WriteString("\nAdd your key at:\n")
|
||||||
|
msg.WriteString("https://ollama.com/settings/keys")
|
||||||
|
|
||||||
|
return errors.New(msg.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
|
||||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -369,6 +442,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||||
|
if spinner != nil {
|
||||||
|
spinner.Stop()
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), "access denied") {
|
||||||
|
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||||
|
}
|
||||||
|
host := model.ParseName(args[0]).Host
|
||||||
|
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
|
||||||
|
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||||
|
// the user has not added their ollama key to ollama.com
|
||||||
|
// re-throw an error with a more user-friendly message
|
||||||
|
return errFromUnknownKey(err)
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -796,24 +883,27 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||||
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
|
// retrieve the OLLAMA_HOST environment variable
|
||||||
|
ollamaHost, err := api.GetOllamaHost()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
host, port = "127.0.0.1", "11434"
|
return err
|
||||||
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
|
|
||||||
host = ip.String()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := initializeKeypair(); err != nil {
|
if err := initializeKeypair(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
|
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return server.Serve(ln)
|
err = server.Serve(ln)
|
||||||
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func initializeKeypair() error {
|
func initializeKeypair() error {
|
||||||
|
@ -1034,7 +1124,7 @@ Environment Variables:
|
||||||
RunE: ListHandler,
|
RunE: ListHandler,
|
||||||
}
|
}
|
||||||
copyCmd := &cobra.Command{
|
copyCmd := &cobra.Command{
|
||||||
Use: "cp SOURCE TARGET",
|
Use: "cp SOURCE DESTINATION",
|
||||||
Short: "Copy a model",
|
Short: "Copy a model",
|
||||||
Args: cobra.ExactArgs(2),
|
Args: cobra.ExactArgs(2),
|
||||||
PreRunE: checkServerHeartbeat,
|
PreRunE: checkServerHeartbeat,
|
||||||
|
|
|
@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
fmt.Fprintln(os.Stderr, " /show Show model information")
|
fmt.Fprintln(os.Stderr, " /show Show model information")
|
||||||
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
|
||||||
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
|
||||||
|
fmt.Fprintln(os.Stderr, " /clear Clear session context")
|
||||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||||
|
@ -161,7 +162,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
|
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
|
||||||
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
|
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
|
||||||
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
|
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
|
||||||
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
|
fmt.Fprintln(os.Stderr, " /set parameter stop <string> <string> ... Set the stop parameters")
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
}
|
}
|
||||||
fmt.Printf("Created new model '%s'\n", args[1])
|
fmt.Printf("Created new model '%s'\n", args[1])
|
||||||
continue
|
continue
|
||||||
|
case strings.HasPrefix(line, "/clear"):
|
||||||
|
opts.Messages = []api.Message{}
|
||||||
|
fmt.Println("Cleared session context")
|
||||||
|
continue
|
||||||
case strings.HasPrefix(line, "/set"):
|
case strings.HasPrefix(line, "/set"):
|
||||||
args := strings.Fields(line)
|
args := strings.Fields(line)
|
||||||
if len(args) > 1 {
|
if len(args) > 1 {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -31,6 +32,10 @@ type Params struct {
|
||||||
EoSTokenID int `json:"eos_token_id"`
|
EoSTokenID int `json:"eos_token_id"`
|
||||||
HeadDimension int `json:"head_dim"`
|
HeadDimension int `json:"head_dim"`
|
||||||
PaddingTokenID int `json:"pad_token_id"`
|
PaddingTokenID int `json:"pad_token_id"`
|
||||||
|
RopeFrequencyBase float64 `json:"rope_theta"`
|
||||||
|
|
||||||
|
Experts int `json:"num_local_experts"`
|
||||||
|
ExpertsUsed int `json:"num_experts_per_tok"`
|
||||||
|
|
||||||
ByteOrder
|
ByteOrder
|
||||||
}
|
}
|
||||||
|
@ -43,7 +48,7 @@ type ByteOrder interface {
|
||||||
type ModelArch interface {
|
type ModelArch interface {
|
||||||
GetTensors() error
|
GetTensors() error
|
||||||
LoadVocab() error
|
LoadVocab() error
|
||||||
WriteGGUF() (string, error)
|
WriteGGUF(io.WriteSeeker) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelFormat interface {
|
type ModelFormat interface {
|
||||||
|
|
|
@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *GemmaModel) WriteGGUF() (string, error) {
|
func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||||
kv := llm.KV{
|
kv := llm.KV{
|
||||||
"general.architecture": "gemma",
|
"general.architecture": "gemma",
|
||||||
"general.name": m.Name,
|
"general.name": m.Name,
|
||||||
|
@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
|
||||||
"tokenizer.ggml.add_eos_token": false,
|
"tokenizer.ggml.add_eos_token": false,
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.CreateTemp("", "ollama-gguf")
|
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
|
||||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Name(), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *LlamaModel) WriteGGUF() (string, error) {
|
func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||||
kv := llm.KV{
|
kv := llm.KV{
|
||||||
"general.architecture": "llama",
|
"general.architecture": "llama",
|
||||||
"general.name": m.Name,
|
"general.name": m.Name,
|
||||||
|
@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
|
||||||
"tokenizer.ggml.add_eos_token": false,
|
"tokenizer.ggml.add_eos_token": false,
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.CreateTemp("", "ollama-gguf")
|
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
|
||||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
|
|
||||||
|
|
||||||
return f.Name(), nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MistralModel) WriteGGUF() (string, error) {
|
func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||||
kv := llm.KV{
|
kv := llm.KV{
|
||||||
"general.architecture": "llama",
|
"general.architecture": "llama",
|
||||||
"general.name": m.Name,
|
"general.name": m.Name,
|
||||||
|
@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
|
||||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.CreateTemp("", "ollama-gguf")
|
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
|
||||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Name(), nil
|
|
||||||
}
|
}
|
||||||
|
|
85
convert/mixtral.go
Normal file
85
convert/mixtral.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MixtralModel struct {
|
||||||
|
ModelData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MixtralModel) GetTensors() error {
|
||||||
|
t, err := m.Format.GetTensors(m.Path, m.Params)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Tensors = []llm.Tensor{}
|
||||||
|
|
||||||
|
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, l := range t {
|
||||||
|
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||||
|
if len(matches) > 0 {
|
||||||
|
wt := l.WriterTo.(safetensorWriterTo)
|
||||||
|
wt.handler = mistralLayerHandler
|
||||||
|
l.WriterTo = wt
|
||||||
|
}
|
||||||
|
m.Tensors = append(m.Tensors, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MixtralModel) LoadVocab() error {
|
||||||
|
v, err := LoadSentencePieceTokens(m.Path, m.Params)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Vocab = v
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
|
||||||
|
kv := llm.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"general.name": m.Name,
|
||||||
|
"llama.block_count": uint32(m.Params.HiddenLayers),
|
||||||
|
"llama.context_length": uint32(m.Params.ContextSize),
|
||||||
|
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
||||||
|
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||||
|
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||||
|
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||||
|
|
||||||
|
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
|
||||||
|
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||||
|
|
||||||
|
"llama.expert_count": uint32(m.Params.Experts),
|
||||||
|
"llama.expert_used_count": uint32(m.Params.ExpertsUsed),
|
||||||
|
|
||||||
|
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
|
||||||
|
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
||||||
|
|
||||||
|
"general.file_type": uint32(1),
|
||||||
|
"tokenizer.ggml.model": "llama",
|
||||||
|
|
||||||
|
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||||
|
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||||
|
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||||
|
|
||||||
|
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||||
|
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||||
|
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||||
|
"tokenizer.ggml.add_bos_token": true,
|
||||||
|
"tokenizer.ggml.add_eos_token": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
|
||||||
|
}
|
|
@ -53,7 +53,7 @@ func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Ten
|
||||||
var err error
|
var err error
|
||||||
t, offset, err = m.readTensors(f, offset, params)
|
t, offset, err = m.readTensors(f, offset, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("%v", err)
|
slog.Error(err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tensors = append(tensors, t...)
|
tensors = append(tensors, t...)
|
||||||
|
@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.Sort(keys)
|
slices.Sort(keys)
|
||||||
|
|
||||||
slog.Info("converting layers")
|
slog.Info("converting layers")
|
||||||
|
|
||||||
var tensors []llm.Tensor
|
var tensors []llm.Tensor
|
||||||
|
@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug(fmt.Sprintf("metadata = %#v", data))
|
|
||||||
var size uint64
|
var size uint64
|
||||||
var kind uint32
|
var kind uint32
|
||||||
switch len(data.Shape) {
|
switch len(data.Shape) {
|
||||||
|
@ -124,7 +122,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
|
||||||
|
|
||||||
ggufName, err := m.GetLayerName(k)
|
ggufName, err := m.GetLayerName(k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("%v", err)
|
slog.Error(err.Error())
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
|
||||||
padding: 8 + jsonSize,
|
padding: 8 + jsonSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
tensors = append(tensors, t)
|
|
||||||
offset += size
|
offset += size
|
||||||
|
tensors = append(tensors, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
|
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
|
||||||
slog.Debug(fmt.Sprintf("offset = %d", offset))
|
slog.Debug(fmt.Sprintf("offset = %d", offset))
|
||||||
|
|
||||||
return tensors, offset, nil
|
return tensors, offset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,6 +194,10 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
|
||||||
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
|
||||||
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
|
||||||
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
|
||||||
|
"model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
|
||||||
|
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
|
||||||
|
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
|
||||||
|
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
|
||||||
}
|
}
|
||||||
|
|
||||||
v, ok := directMap[n]
|
v, ok := directMap[n]
|
||||||
|
@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
|
||||||
Format: m,
|
Format: m,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
|
case "MixtralForCausalLM":
|
||||||
|
return &MixtralModel{
|
||||||
|
ModelData{
|
||||||
|
Name: name,
|
||||||
|
Path: dirPath,
|
||||||
|
Params: params,
|
||||||
|
Format: m,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
case "GemmaForCausalLM":
|
case "GemmaForCausalLM":
|
||||||
return &GemmaModel{
|
return &GemmaModel{
|
||||||
ModelData{
|
ModelData{
|
||||||
|
|
|
@ -74,7 +74,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
|
||||||
|
|
||||||
ggufName, err := tf.GetLayerName(k.(string))
|
ggufName, err := tf.GetLayerName(k.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("%v", err)
|
slog.Error(err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))
|
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))
|
||||||
|
|
62
docs/api.md
62
docs/api.md
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
### Model names
|
### Model names
|
||||||
|
|
||||||
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
|
||||||
|
|
||||||
### Durations
|
### Durations
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/generate -d '{
|
curl http://localhost:11434/api/generate -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"prompt": "Why is the sky blue?"
|
"prompt": "Why is the sky blue?"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
@ -77,7 +77,7 @@ A stream of JSON objects is returned:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||||
"response": "The",
|
"response": "The",
|
||||||
"done": false
|
"done": false
|
||||||
|
@ -90,16 +90,16 @@ The final response in the stream also includes additional data about the generat
|
||||||
- `load_duration`: time spent in nanoseconds loading the model
|
- `load_duration`: time spent in nanoseconds loading the model
|
||||||
- `prompt_eval_count`: number of tokens in the prompt
|
- `prompt_eval_count`: number of tokens in the prompt
|
||||||
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
|
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
|
||||||
- `eval_count`: number of tokens the response
|
- `eval_count`: number of tokens in the response
|
||||||
- `eval_duration`: time in nanoseconds spent generating the response
|
- `eval_duration`: time in nanoseconds spent generating the response
|
||||||
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
|
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
|
||||||
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
|
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
|
||||||
|
|
||||||
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
|
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` * `10^9`.
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
"response": "",
|
"response": "",
|
||||||
"done": true,
|
"done": true,
|
||||||
|
@ -121,7 +121,7 @@ A response can be received in one reply when streaming is off.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/generate -d '{
|
curl http://localhost:11434/api/generate -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"prompt": "Why is the sky blue?",
|
"prompt": "Why is the sky blue?",
|
||||||
"stream": false
|
"stream": false
|
||||||
}'
|
}'
|
||||||
|
@ -133,7 +133,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
"response": "The sky is blue because it is the color of the sky.",
|
"response": "The sky is blue because it is the color of the sky.",
|
||||||
"done": true,
|
"done": true,
|
||||||
|
@ -155,7 +155,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/generate -d '{
|
curl http://localhost:11434/api/generate -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"prompt": "What color is the sky at different times of the day? Respond using JSON",
|
"prompt": "What color is the sky at different times of the day? Respond using JSON",
|
||||||
"format": "json",
|
"format": "json",
|
||||||
"stream": false
|
"stream": false
|
||||||
|
@ -166,7 +166,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-11-09T21:07:55.186497Z",
|
"created_at": "2023-11-09T21:07:55.186497Z",
|
||||||
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
|
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
|
||||||
"done": true,
|
"done": true,
|
||||||
|
@ -289,7 +289,7 @@ If you want to set custom options for the model at runtime rather than in the Mo
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/generate -d '{
|
curl http://localhost:11434/api/generate -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"prompt": "Why is the sky blue?",
|
"prompt": "Why is the sky blue?",
|
||||||
"stream": false,
|
"stream": false,
|
||||||
"options": {
|
"options": {
|
||||||
|
@ -332,7 +332,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
"response": "The sky is blue because it is the color of the sky.",
|
"response": "The sky is blue because it is the color of the sky.",
|
||||||
"done": true,
|
"done": true,
|
||||||
|
@ -354,7 +354,7 @@ If an empty prompt is provided, the model will be loaded into memory.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/generate -d '{
|
curl http://localhost:11434/api/generate -d '{
|
||||||
"model": "llama2"
|
"model": "llama3"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -364,7 +364,7 @@ A single JSON object is returned:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-12-18T19:52:07.071755Z",
|
"created_at": "2023-12-18T19:52:07.071755Z",
|
||||||
"response": "",
|
"response": "",
|
||||||
"done": true
|
"done": true
|
||||||
|
@ -407,7 +407,7 @@ Send a chat message with a streaming response.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -423,7 +423,7 @@ A stream of JSON objects is returned:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
@ -438,7 +438,7 @@ Final response:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
"done": true,
|
"done": true,
|
||||||
"total_duration": 4883583458,
|
"total_duration": 4883583458,
|
||||||
|
@ -456,7 +456,7 @@ Final response:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -471,7 +471,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "registry.ollama.ai/library/llama2:latest",
|
"model": "registry.ollama.ai/library/llama3:latest",
|
||||||
"created_at": "2023-12-12T14:13:43.416799Z",
|
"created_at": "2023-12-12T14:13:43.416799Z",
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
@ -495,7 +495,7 @@ Send a chat message with a conversation history. You can use this same approach
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -519,7 +519,7 @@ A stream of JSON objects is returned:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
"created_at": "2023-08-04T08:52:19.385406455-07:00",
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
@ -533,7 +533,7 @@ Final response:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"created_at": "2023-08-04T19:22:45.499127Z",
|
"created_at": "2023-08-04T19:22:45.499127Z",
|
||||||
"done": true,
|
"done": true,
|
||||||
"total_duration": 8113331500,
|
"total_duration": 8113331500,
|
||||||
|
@ -591,7 +591,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -609,7 +609,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "registry.ollama.ai/library/llama2:latest",
|
"model": "registry.ollama.ai/library/llama3:latest",
|
||||||
"created_at": "2023-12-12T14:13:43.416799Z",
|
"created_at": "2023-12-12T14:13:43.416799Z",
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
@ -651,7 +651,7 @@ Create a new model from a `Modelfile`.
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/create -d '{
|
curl http://localhost:11434/api/create -d '{
|
||||||
"name": "mario",
|
"name": "mario",
|
||||||
"modelfile": "FROM llama2\nSYSTEM You are mario from Super Mario Bros."
|
"modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros."
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -758,7 +758,7 @@ A single JSON object will be returned.
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "llama2:latest",
|
"name": "llama3:latest",
|
||||||
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
|
||||||
"size": 3825819519,
|
"size": 3825819519,
|
||||||
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
|
||||||
|
@ -792,7 +792,7 @@ Show information about a model including details, modelfile, template, parameter
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/show -d '{
|
curl http://localhost:11434/api/show -d '{
|
||||||
"name": "llama2"
|
"name": "llama3"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -827,8 +827,8 @@ Copy a model. Creates a model with another name from an existing model.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/copy -d '{
|
curl http://localhost:11434/api/copy -d '{
|
||||||
"source": "llama2",
|
"source": "llama3",
|
||||||
"destination": "llama2-backup"
|
"destination": "llama3-backup"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -854,7 +854,7 @@ Delete a model and its data.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl -X DELETE http://localhost:11434/api/delete -d '{
|
curl -X DELETE http://localhost:11434/api/delete -d '{
|
||||||
"name": "llama2:13b"
|
"name": "llama3:13b"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -882,7 +882,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/pull -d '{
|
curl http://localhost:11434/api/pull -d '{
|
||||||
"name": "llama2"
|
"name": "llama3"
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro
|
||||||
or installation approach uses unusual paths, you can specify the location by
|
or installation approach uses unusual paths, you can specify the location by
|
||||||
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
|
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
|
||||||
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
|
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
|
||||||
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
|
the set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70").
|
||||||
|
|
||||||
Then generate dependencies:
|
Then generate dependencies:
|
||||||
|
|
||||||
|
|
18
docs/faq.md
18
docs/faq.md
|
@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter:
|
||||||
|
|
||||||
```
|
```
|
||||||
curl http://localhost:11434/api/generate -d '{
|
curl http://localhost:11434/api/generate -d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"prompt": "Why is the sky blue?",
|
"prompt": "Why is the sky blue?",
|
||||||
"options": {
|
"options": {
|
||||||
"num_ctx": 4096
|
"num_ctx": 4096
|
||||||
|
@ -140,7 +140,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
|
||||||
|
|
||||||
- macOS: `~/.ollama/models`
|
- macOS: `~/.ollama/models`
|
||||||
- Linux: `/usr/share/ollama/.ollama/models`
|
- Linux: `/usr/share/ollama/.ollama/models`
|
||||||
- Windows: `C:\Users\<username>\.ollama\models`
|
- Windows: `C:\Users\%username%\.ollama\models`
|
||||||
|
|
||||||
### How do I set them to a different location?
|
### How do I set them to a different location?
|
||||||
|
|
||||||
|
@ -221,10 +221,20 @@ The `keep_alive` parameter can be set to:
|
||||||
|
|
||||||
For example, to preload a model and leave it in memory use:
|
For example, to preload a model and leave it in memory use:
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}'
|
curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": -1}'
|
||||||
```
|
```
|
||||||
|
|
||||||
To unload the model and free up memory use:
|
To unload the model and free up memory use:
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
|
curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable accepts the same value types as the `keep_alive` parameter described above. Refer to the section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
|
||||||
|
|
||||||
|
If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
|
||||||
|
|
||||||
|
## How do I manage the maximum number of requests the server can queue?
|
||||||
|
|
||||||
|
If too many requests are sent to the server, it will respond with a 503 error
|
||||||
|
indicating the server is overloaded. You can adjust how many requests may be
|
||||||
|
queued by setting `OLLAMA_MAX_QUEUE`.
|
|
@ -125,7 +125,7 @@ Publishing models is in early alpha. If you'd like to publish your model to shar
|
||||||
|
|
||||||
1. Create [an account](https://ollama.com/signup)
|
1. Create [an account](https://ollama.com/signup)
|
||||||
2. Copy your Ollama public key:
|
2. Copy your Ollama public key:
|
||||||
- macOS: `cat ~/.ollama/id_ed25519.pub`
|
- macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
|
||||||
- Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
|
- Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
|
||||||
- Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
|
- Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
|
||||||
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
|
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
|
||||||
|
@ -136,6 +136,8 @@ Next, copy your model to your username's namespace:
|
||||||
ollama cp example <your username>/example
|
ollama cp example <your username>/example
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
|
||||||
|
|
||||||
Then push the model:
|
Then push the model:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -105,7 +105,7 @@ sudo chmod +x /usr/bin/ollama
|
||||||
To view logs of Ollama running as a startup service, run:
|
To view logs of Ollama running as a startup service, run:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
journalctl -u ollama
|
journalctl -e -u ollama
|
||||||
```
|
```
|
||||||
|
|
||||||
## Uninstall
|
## Uninstall
|
||||||
|
|
|
@ -10,7 +10,7 @@ A model file is the blueprint to create and share models with Ollama.
|
||||||
- [Examples](#examples)
|
- [Examples](#examples)
|
||||||
- [Instructions](#instructions)
|
- [Instructions](#instructions)
|
||||||
- [FROM (Required)](#from-required)
|
- [FROM (Required)](#from-required)
|
||||||
- [Build from llama2](#build-from-llama2)
|
- [Build from llama3](#build-from-llama3)
|
||||||
- [Build from a bin file](#build-from-a-bin-file)
|
- [Build from a bin file](#build-from-a-bin-file)
|
||||||
- [PARAMETER](#parameter)
|
- [PARAMETER](#parameter)
|
||||||
- [Valid Parameters and Values](#valid-parameters-and-values)
|
- [Valid Parameters and Values](#valid-parameters-and-values)
|
||||||
|
@ -48,7 +48,7 @@ INSTRUCTION arguments
|
||||||
An example of a `Modelfile` creating a mario blueprint:
|
An example of a `Modelfile` creating a mario blueprint:
|
||||||
|
|
||||||
```modelfile
|
```modelfile
|
||||||
FROM llama2
|
FROM llama3
|
||||||
# sets the temperature to 1 [higher is more creative, lower is more coherent]
|
# sets the temperature to 1 [higher is more creative, lower is more coherent]
|
||||||
PARAMETER temperature 1
|
PARAMETER temperature 1
|
||||||
# sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
|
# sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
|
||||||
|
@ -67,33 +67,25 @@ To use this:
|
||||||
|
|
||||||
More examples are available in the [examples directory](../examples).
|
More examples are available in the [examples directory](../examples).
|
||||||
|
|
||||||
### `Modelfile`s in [ollama.com/library][1]
|
To view the Modelfile of a given model, use the `ollama show --modelfile` command.
|
||||||
|
|
||||||
There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
|
|
||||||
|
|
||||||
- Option 1: view a details page from a model's tags page:
|
|
||||||
1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
|
|
||||||
2. Click on a tag (e.g. https://ollama.com/library/llama2:13b)
|
|
||||||
3. Scroll down to "Layers"
|
|
||||||
- Note: if the [`FROM` instruction](#from-required) is not present,
|
|
||||||
it means the model was created from a local file
|
|
||||||
- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
> ollama show --modelfile llama2:13b
|
> ollama show --modelfile llama3
|
||||||
# Modelfile generated by "ollama show"
|
# Modelfile generated by "ollama show"
|
||||||
# To build a new Modelfile based on this one, replace the FROM line with:
|
# To build a new Modelfile based on this one, replace the FROM line with:
|
||||||
# FROM llama2:13b
|
# FROM llama3:latest
|
||||||
|
FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
|
||||||
|
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
FROM /root/.ollama/models/blobs/sha256:123abc
|
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
||||||
TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
|
|
||||||
|
|
||||||
{{ end }}{{ .Prompt }} [/INST] """
|
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
||||||
SYSTEM """"""
|
|
||||||
PARAMETER stop [INST]
|
{{ .Response }}<|eot_id|>"""
|
||||||
PARAMETER stop [/INST]
|
PARAMETER stop "<|start_header_id|>"
|
||||||
PARAMETER stop <<SYS>>
|
PARAMETER stop "<|end_header_id|>"
|
||||||
PARAMETER stop <</SYS>>
|
PARAMETER stop "<|eot_id|>"
|
||||||
|
PARAMETER stop "<|reserved_special_token"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Instructions
|
## Instructions
|
||||||
|
@ -106,10 +98,10 @@ The `FROM` instruction defines the base model to use when creating a model.
|
||||||
FROM <model name>:<tag>
|
FROM <model name>:<tag>
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Build from llama2
|
#### Build from llama3
|
||||||
|
|
||||||
```modelfile
|
```modelfile
|
||||||
FROM llama2
|
FROM llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
A list of available base models:
|
A list of available base models:
|
||||||
|
|
|
@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create(
|
||||||
'content': 'Say this is a test',
|
'content': 'Say this is a test',
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
model='llama2',
|
model='llama3',
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ const openai = new OpenAI({
|
||||||
|
|
||||||
const chatCompletion = await openai.chat.completions.create({
|
const chatCompletion = await openai.chat.completions.create({
|
||||||
messages: [{ role: 'user', content: 'Say this is a test' }],
|
messages: [{ role: 'user', content: 'Say this is a test' }],
|
||||||
model: 'llama2',
|
model: 'llama3',
|
||||||
})
|
})
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ const chatCompletion = await openai.chat.completions.create({
|
||||||
curl http://localhost:11434/v1/chat/completions \
|
curl http://localhost:11434/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d '{
|
-d '{
|
||||||
"model": "llama2",
|
"model": "llama3",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -113,7 +113,7 @@ curl http://localhost:11434/v1/chat/completions \
|
||||||
Before using a model, pull it locally with `ollama pull`:
|
Before using a model, pull it locally with `ollama pull`:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ollama pull llama2
|
ollama pull llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
### Default model names
|
### Default model names
|
||||||
|
@ -121,7 +121,7 @@ ollama pull llama2
|
||||||
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
|
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
|
||||||
|
|
||||||
```
|
```
|
||||||
ollama cp llama2 gpt-3.5-turbo
|
ollama cp llama3 gpt-3.5-turbo
|
||||||
```
|
```
|
||||||
|
|
||||||
Afterwards, this new model name can be specified in the `model` field:
|
Afterwards, this new model name can be specified in the `model` field:
|
||||||
|
|
|
@ -15,7 +15,7 @@ import { Ollama } from "langchain/llms/ollama";
|
||||||
|
|
||||||
const ollama = new Ollama({
|
const ollama = new Ollama({
|
||||||
baseUrl: "http://localhost:11434",
|
baseUrl: "http://localhost:11434",
|
||||||
model: "llama2",
|
model: "llama3",
|
||||||
});
|
});
|
||||||
|
|
||||||
const answer = await ollama.invoke(`why is the sky blue?`);
|
const answer = await ollama.invoke(`why is the sky blue?`);
|
||||||
|
@ -23,7 +23,7 @@ const answer = await ollama.invoke(`why is the sky blue?`);
|
||||||
console.log(answer);
|
console.log(answer);
|
||||||
```
|
```
|
||||||
|
|
||||||
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
|
That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
npm install cheerio
|
npm install cheerio
|
||||||
|
|
|
@ -12,15 +12,17 @@ So let's figure out how we can use **LangChain** with Ollama to ask our question
|
||||||
|
|
||||||
Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
|
Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
|
||||||
|
|
||||||
`pip install langchain`
|
`pip install langchain_community`
|
||||||
|
|
||||||
Then we can create a model and ask the question:
|
Then we can create a model and ask the question:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from langchain.llms import Ollama
|
from langchain_community.llms import Ollama
|
||||||
ollama = Ollama(base_url='http://localhost:11434',
|
ollama = Ollama(
|
||||||
model="llama2")
|
base_url='http://localhost:11434',
|
||||||
print(ollama("why is the sky blue"))
|
model="llama3"
|
||||||
|
)
|
||||||
|
print(ollama.invoke("why is the sky blue"))
|
||||||
```
|
```
|
||||||
|
|
||||||
Notice that we are defining the model and the base URL for Ollama.
|
Notice that we are defining the model and the base URL for Ollama.
|
||||||
|
|
|
@ -1,38 +1,15 @@
|
||||||
# Running Ollama on NVIDIA Jetson Devices
|
# Running Ollama on NVIDIA Jetson Devices
|
||||||
|
|
||||||
With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
|
Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.
|
||||||
|
|
||||||
NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
|
The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
|
||||||
|
|
||||||
Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
|
|
||||||
mode. This can be verified by using a monitoring tool like jtop.
|
|
||||||
|
|
||||||
In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
|
|
||||||
version of our target model.
|
|
||||||
|
|
||||||
Prerequisites:
|
|
||||||
|
|
||||||
- curl
|
|
||||||
- tmux
|
|
||||||
|
|
||||||
Here are the steps:
|
|
||||||
|
|
||||||
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
|
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
|
||||||
- Stop the Ollama service: `sudo systemctl stop ollama`
|
|
||||||
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
|
|
||||||
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
|
|
||||||
- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
|
- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
|
||||||
- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
|
- Start an interactive session: `ollama run mistral`
|
||||||
- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
|
|
||||||
|
|
||||||
```
|
|
||||||
FROM mistral
|
|
||||||
PARAMETER num_gpu 999
|
|
||||||
```
|
|
||||||
|
|
||||||
- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
|
|
||||||
- Run the new model: `ollama run mistral-jetson`
|
|
||||||
|
|
||||||
If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
|
|
||||||
|
|
||||||
And that's it!
|
And that's it!
|
||||||
|
|
||||||
|
# Running Ollama in Docker
|
||||||
|
|
||||||
|
When running GPU accelerated applications in Docker, it is highly recommended to use the [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).
|
|
@ -14,7 +14,7 @@ As this is a preview release, you should expect a few bugs here and there. If
|
||||||
you run into a problem you can reach out on
|
you run into a problem you can reach out on
|
||||||
[Discord](https://discord.gg/ollama), or file an
|
[Discord](https://discord.gg/ollama), or file an
|
||||||
[issue](https://github.com/ollama/ollama/issues).
|
[issue](https://github.com/ollama/ollama/issues).
|
||||||
Logs will often be helpful in dianosing the problem (see
|
Logs will often be helpful in diagnosing the problem (see
|
||||||
[Troubleshooting](#troubleshooting) below)
|
[Troubleshooting](#troubleshooting) below)
|
||||||
|
|
||||||
## System Requirements
|
## System Requirements
|
||||||
|
@ -27,7 +27,7 @@ Logs will often be helpful in dianosing the problem (see
|
||||||
|
|
||||||
Here's a quick example showing API access from `powershell`
|
Here's a quick example showing API access from `powershell`
|
||||||
```powershell
|
```powershell
|
||||||
(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
|
(Invoke-WebRequest -method POST -Body '{"model":"llama3", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
|
||||||
```
|
```
|
||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
@ -45,3 +45,17 @@ the explorer window by hitting `<cmd>+R` and type in:
|
||||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
|
||||||
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
- `explorer %HOMEPATH%\.ollama` contains models and configuration
|
||||||
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
|
||||||
|
|
||||||
|
|
||||||
|
## Standalone CLI
|
||||||
|
|
||||||
|
The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
|
||||||
|
installer. It installs in your account without requiring Administrator rights.
|
||||||
|
We update Ollama regularly to support the latest models, and this installer will
|
||||||
|
help you keep up to date.
|
||||||
|
|
||||||
|
If you'd like to install or integrate Ollama as a service, a standalone
|
||||||
|
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
||||||
|
and GPU library dependencies for Nvidia and AMD. This allows for embedding
|
||||||
|
Ollama in existing applications, or running it as a system service via `ollama
|
||||||
|
serve` with tools such as [NSSM](https://nssm.cc/).
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
|
When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
|
||||||
|
|
||||||
`ollama run llama2 < sourcequestions.txt`
|
`ollama run llama3 < sourcequestions.txt`
|
||||||
|
|
||||||
This concept is used in the following example.
|
This concept is used in the following example.
|
||||||
|
|
||||||
|
|
1
examples/flyio/.gitignore
vendored
Normal file
1
examples/flyio/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
||||||
|
fly.toml
|
67
examples/flyio/README.md
Normal file
67
examples/flyio/README.md
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
# Deploy Ollama to Fly.io
|
||||||
|
|
||||||
|
> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Ollama: https://ollama.com/download
|
||||||
|
- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
|
||||||
|
|
||||||
|
## Steps
|
||||||
|
|
||||||
|
1. Login to Fly.io
|
||||||
|
|
||||||
|
```bash
|
||||||
|
fly auth login
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Create a new Fly app
|
||||||
|
|
||||||
|
```bash
|
||||||
|
fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Pull and run `orca-mini:3b`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
|
||||||
|
```
|
||||||
|
|
||||||
|
`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type or attach a GPU for hardware acceleration (see below).
|
||||||
|
|
||||||
|
## (Optional) Persistent Volume
|
||||||
|
|
||||||
|
By default Fly Machines use ephemeral storage which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
|
||||||
|
|
||||||
|
1. Create the Fly Volume
|
||||||
|
|
||||||
|
```bash
|
||||||
|
fly volume create ollama
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Update `fly.toml` and add `[mounts]`
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[mounts]
|
||||||
|
source = "ollama"
|
||||||
|
destination = "/mnt/ollama/models"
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Update `fly.toml` and add `[env]`
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[env]
|
||||||
|
OLLAMA_MODELS = "/mnt/ollama/models"
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Deploy your app
|
||||||
|
|
||||||
|
```bash
|
||||||
|
fly deploy
|
||||||
|
```
|
||||||
|
|
||||||
|
## (Optional) Hardware Acceleration
|
||||||
|
|
||||||
|
Fly.io GPU access is currently waitlisted. Sign up for the waitlist: https://fly.io/gpu
|
||||||
|
|
||||||
|
Once you've been accepted, create the app with the additional flags `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.
|
|
@ -35,7 +35,7 @@ func main() {
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
req := &api.ChatRequest{
|
req := &api.ChatRequest{
|
||||||
Model: "llama2",
|
Model: "llama3",
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,24 @@
|
||||||
|
|
||||||
## Steps
|
## Steps
|
||||||
|
|
||||||
1. Create the Ollama namespace, daemon set, and service
|
1. Create the Ollama namespace, deployment, and service
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
kubectl apply -f cpu.yaml
|
kubectl apply -f cpu.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## (Optional) Hardware Acceleration
|
||||||
|
|
||||||
|
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin), which is deployed in Kubernetes in the form of a DaemonSet. Follow the link for more details.
|
||||||
|
|
||||||
|
Once configured, create a GPU enabled Ollama deployment.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
kubectl apply -f gpu.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test
|
||||||
|
|
||||||
1. Port forward the Ollama service to connect and use it locally
|
1. Port forward the Ollama service to connect and use it locally
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -24,13 +36,3 @@
|
||||||
```bash
|
```bash
|
||||||
ollama run orca-mini:3b
|
ollama run orca-mini:3b
|
||||||
```
|
```
|
||||||
|
|
||||||
## (Optional) Hardware Acceleration
|
|
||||||
|
|
||||||
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
|
|
||||||
|
|
||||||
Once configured, create a GPU enabled Ollama deployment.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
kubectl apply -f gpu.yaml
|
|
||||||
```
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ while True:
|
||||||
template=template,
|
template=template,
|
||||||
)
|
)
|
||||||
|
|
||||||
llm = Ollama(model="llama2:13b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
|
llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
|
||||||
qa_chain = RetrievalQA.from_chain_type(
|
qa_chain = RetrievalQA.from_chain_type(
|
||||||
llm,
|
llm,
|
||||||
retriever=vectorstore.as_retriever(),
|
retriever=vectorstore.as_retriever(),
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from langchain.llms import Ollama
|
from langchain_community.llms import Ollama
|
||||||
from langchain.document_loaders import WebBaseLoader
|
from langchain_community.document_loaders import WebBaseLoader
|
||||||
from langchain.chains.summarize import load_summarize_chain
|
from langchain.chains.summarize import load_summarize_chain
|
||||||
|
|
||||||
loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
|
loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
|
||||||
docs = loader.load()
|
docs = loader.load()
|
||||||
|
|
||||||
llm = Ollama(model="llama2")
|
llm = Ollama(model="llama3")
|
||||||
chain = load_summarize_chain(llm, chain_type="stuff")
|
chain = load_summarize_chain(llm, chain_type="stuff")
|
||||||
|
|
||||||
result = chain.run(docs)
|
result = chain.invoke(docs)
|
||||||
print(result)
|
print(result)
|
||||||
|
|
|
@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama.
|
||||||
|
|
||||||
## Running the Example
|
## Running the Example
|
||||||
|
|
||||||
1. Ensure you have the `llama2` model installed:
|
1. Ensure you have the `llama3` model installed:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ollama pull llama2
|
ollama pull llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install the Python Requirements.
|
2. Install the Python Requirements.
|
||||||
|
@ -21,4 +21,3 @@ This example is a basic "hello world" of using LangChain with Ollama.
|
||||||
```bash
|
```bash
|
||||||
python main.py
|
python main.py
|
||||||
```
|
```
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from langchain.llms import Ollama
|
from langchain.llms import Ollama
|
||||||
|
|
||||||
input = input("What is your question?")
|
input = input("What is your question?")
|
||||||
llm = Ollama(model="llama2")
|
llm = Ollama(model="llama3")
|
||||||
res = llm.predict(input)
|
res = llm.predict(input)
|
||||||
print (res)
|
print (res)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
FROM llama2
|
FROM llama3
|
||||||
PARAMETER temperature 1
|
PARAMETER temperature 1
|
||||||
SYSTEM """
|
SYSTEM """
|
||||||
You are Mario from super mario bros, acting as an assistant.
|
You are Mario from super mario bros, acting as an assistant.
|
||||||
|
|
|
@ -2,12 +2,12 @@
|
||||||
|
|
||||||
# Example character: Mario
|
# Example character: Mario
|
||||||
|
|
||||||
This example shows how to create a basic character using Llama2 as the base model.
|
This example shows how to create a basic character using Llama3 as the base model.
|
||||||
|
|
||||||
To run this example:
|
To run this example:
|
||||||
|
|
||||||
1. Download the Modelfile
|
1. Download the Modelfile
|
||||||
2. `ollama pull llama2` to get the base model used in the model file.
|
2. `ollama pull llama3` to get the base model used in the model file.
|
||||||
3. `ollama create NAME -f ./Modelfile`
|
3. `ollama create NAME -f ./Modelfile`
|
||||||
4. `ollama run NAME`
|
4. `ollama run NAME`
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?"
|
||||||
What the model file looks like:
|
What the model file looks like:
|
||||||
|
|
||||||
```
|
```
|
||||||
FROM llama2
|
FROM llama3
|
||||||
PARAMETER temperature 1
|
PARAMETER temperature 1
|
||||||
SYSTEM """
|
SYSTEM """
|
||||||
You are Mario from Super Mario Bros, acting as an assistant.
|
You are Mario from Super Mario Bros, acting as an assistant.
|
||||||
|
|
|
@ -2,7 +2,7 @@ import requests
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
|
||||||
model = "llama2"
|
model = "llama3"
|
||||||
template = {
|
template = {
|
||||||
"firstName": "",
|
"firstName": "",
|
||||||
"lastName": "",
|
"lastName": "",
|
||||||
|
|
|
@ -12,7 +12,7 @@ countries = [
|
||||||
"France",
|
"France",
|
||||||
]
|
]
|
||||||
country = random.choice(countries)
|
country = random.choice(countries)
|
||||||
model = "llama2"
|
model = "llama3"
|
||||||
|
|
||||||
prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."
|
prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran
|
||||||
|
|
||||||
## Running the Example
|
## Running the Example
|
||||||
|
|
||||||
1. Ensure you have the `llama2` model installed:
|
1. Ensure you have the `llama3` model installed:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ollama pull llama2
|
ollama pull llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install the Python Requirements.
|
2. Install the Python Requirements.
|
||||||
|
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
|
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
|
||||||
model = "llama2" # TODO: update this for whatever model you wish to use
|
model = "llama3" # TODO: update this for whatever model you wish to use
|
||||||
|
|
||||||
|
|
||||||
def chat(messages):
|
def chat(messages):
|
||||||
|
|
|
@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam
|
||||||
|
|
||||||
## Running the Example
|
## Running the Example
|
||||||
|
|
||||||
1. Ensure you have the `llama2` model installed:
|
1. Ensure you have the `llama3` model installed:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ollama pull llama2
|
ollama pull llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install the Python Requirements.
|
2. Install the Python Requirements.
|
||||||
|
|
|
@ -4,10 +4,10 @@ This example demonstrates how one would create a set of 'mentors' you can have a
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
1. Add llama2 to have the mentors ask your questions:
|
1. Add llama3 to have the mentors ask your questions:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ollama pull llama2
|
ollama pull llama3
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install prerequisites:
|
2. Install prerequisites:
|
||||||
|
|
|
@ -15,7 +15,7 @@ async function characterGenerator() {
|
||||||
ollama.setModel("stablebeluga2:70b-q4_K_M");
|
ollama.setModel("stablebeluga2:70b-q4_K_M");
|
||||||
const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
|
const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
|
||||||
|
|
||||||
const thecontents = `FROM llama2\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
|
const thecontents = `FROM llama3\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
|
||||||
|
|
||||||
fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
|
fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
|
||||||
if (err) throw err;
|
if (err) throw err;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import * as readline from "readline";
|
import * as readline from "readline";
|
||||||
|
|
||||||
const model = "llama2";
|
const model = "llama3";
|
||||||
type Message = {
|
type Message = {
|
||||||
role: "assistant" | "user" | "system";
|
role: "assistant" | "user" | "system";
|
||||||
content: string;
|
content: string;
|
||||||
|
|
|
@ -15,6 +15,7 @@ const (
|
||||||
|
|
||||||
KibiByte = Byte * 1024
|
KibiByte = Byte * 1024
|
||||||
MebiByte = KibiByte * 1024
|
MebiByte = KibiByte * 1024
|
||||||
|
GibiByte = MebiByte * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func HumanBytes(b int64) string {
|
func HumanBytes(b int64) string {
|
||||||
|
@ -52,6 +53,8 @@ func HumanBytes(b int64) string {
|
||||||
|
|
||||||
func HumanBytes2(b uint64) string {
|
func HumanBytes2(b uint64) string {
|
||||||
switch {
|
switch {
|
||||||
|
case b >= GibiByte:
|
||||||
|
return fmt.Sprintf("%.1f GiB", float64(b)/GibiByte)
|
||||||
case b >= MebiByte:
|
case b >= MebiByte:
|
||||||
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
|
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
|
||||||
case b >= KibiByte:
|
case b >= KibiByte:
|
||||||
|
|
|
@ -13,12 +13,20 @@ const (
|
||||||
|
|
||||||
func HumanNumber(b uint64) string {
|
func HumanNumber(b uint64) string {
|
||||||
switch {
|
switch {
|
||||||
case b > Billion:
|
case b >= Billion:
|
||||||
return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
|
number := float64(b) / Billion
|
||||||
case b > Million:
|
if number == math.Floor(number) {
|
||||||
return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
|
return fmt.Sprintf("%.0fB", number) // no decimals if whole number
|
||||||
case b > Thousand:
|
}
|
||||||
return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
|
return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
|
||||||
|
case b >= Million:
|
||||||
|
number := float64(b) / Million
|
||||||
|
if number == math.Floor(number) {
|
||||||
|
return fmt.Sprintf("%.0fM", number) // no decimals if whole number
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
|
||||||
|
case b >= Thousand:
|
||||||
|
return fmt.Sprintf("%.0fK", float64(b)/Thousand)
|
||||||
default:
|
default:
|
||||||
return fmt.Sprintf("%d", b)
|
return fmt.Sprintf("%d", b)
|
||||||
}
|
}
|
||||||
|
|
34
format/format_test.go
Normal file
34
format/format_test.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package format
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHumanNumber(t *testing.T) {
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
input uint64
|
||||||
|
expected string
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{0, "0"},
|
||||||
|
{1000000, "1M"},
|
||||||
|
{125000000, "125M"},
|
||||||
|
{500500000, "500.50M"},
|
||||||
|
{500550000, "500.55M"},
|
||||||
|
{1000000000, "1B"},
|
||||||
|
{2800000000, "2.8B"},
|
||||||
|
{2850000000, "2.9B"},
|
||||||
|
{1000000000000, "1000B"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.expected, func(t *testing.T) {
|
||||||
|
result := HumanNumber(tc.input)
|
||||||
|
if result != tc.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tc.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -35,22 +35,66 @@ func GetSupportedGFX(libDir string) ([]string, error) {
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
|
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||||
// Set the visible devices if not already set
|
ids := []string{}
|
||||||
// TODO - does sort order matter?
|
for _, info := range gpuInfo {
|
||||||
devices := []string{}
|
if info.Library != "rocm" {
|
||||||
for i := range ids {
|
// TODO shouldn't happen if things are wired correctly...
|
||||||
if _, skipped := skip[i]; skipped {
|
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
devices = append(devices, strconv.Itoa(i))
|
ids = append(ids, info.ID)
|
||||||
|
}
|
||||||
|
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func commonAMDValidateLibDir() (string, error) {
|
||||||
|
// We try to favor system paths first, so that we can wire up the subprocess to use
|
||||||
|
// the system version. Only use our bundled version if the system version doesn't work
|
||||||
|
// This gives users more recovery options if versions have subtle problems at runtime
|
||||||
|
|
||||||
|
// Prefer explicit HIP env var
|
||||||
|
hipPath := os.Getenv("HIP_PATH")
|
||||||
|
if hipPath != "" {
|
||||||
|
hipLibDir := filepath.Join(hipPath, "bin")
|
||||||
|
if rocmLibUsable(hipLibDir) {
|
||||||
|
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
||||||
|
return hipLibDir, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val := strings.Join(devices, ",")
|
// Scan the LD_LIBRARY_PATH or PATH
|
||||||
err := os.Setenv("HIP_VISIBLE_DEVICES", val)
|
pathEnv := "LD_LIBRARY_PATH"
|
||||||
if err != nil {
|
if runtime.GOOS == "windows" {
|
||||||
slog.Warn(fmt.Sprintf("failed to set env: %s", err))
|
pathEnv = "PATH"
|
||||||
} else {
|
|
||||||
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
paths := os.Getenv(pathEnv)
|
||||||
|
for _, path := range filepath.SplitList(paths) {
|
||||||
|
d, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rocmLibUsable(d) {
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Well known location(s)
|
||||||
|
for _, path := range RocmStandardLocations {
|
||||||
|
if rocmLibUsable(path) {
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Installer payload location if we're running the installed binary
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err == nil {
|
||||||
|
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
||||||
|
if rocmLibUsable(rocmTargetDir) {
|
||||||
|
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
||||||
|
return rocmTargetDir, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
|
||||||
func (hl *HipLib) Release() {
|
func (hl *HipLib) Release() {
|
||||||
err := windows.FreeLibrary(hl.dll)
|
err := windows.FreeLibrary(hl.dll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
|
slog.Warn("failed to unload amdhip64.dll", "error", err)
|
||||||
}
|
}
|
||||||
hl.dll = 0
|
hl.dll = 0
|
||||||
}
|
}
|
||||||
|
@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
if status != hipSuccess {
|
if status != hipSuccess {
|
||||||
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
|
slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
|
||||||
}
|
}
|
||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
473
gpu/amd_linux.go
473
gpu/amd_linux.go
|
@ -11,6 +11,8 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Discovery logic for AMD/ROCm GPUs
|
// Discovery logic for AMD/ROCm GPUs
|
||||||
|
@ -23,26 +25,20 @@ const (
|
||||||
// Prefix with the node dir
|
// Prefix with the node dir
|
||||||
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
||||||
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
|
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
|
||||||
RocmStandardLocation = "/opt/rocm/lib"
|
|
||||||
|
|
||||||
// TODO find a better way to detect iGPU instead of minimum memory
|
|
||||||
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Used to validate if the given ROCm lib is usable
|
// Used to validate if the given ROCm lib is usable
|
||||||
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
|
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
|
||||||
|
RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
||||||
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
|
func AMDGetGPUInfo() []GpuInfo {
|
||||||
// and the user hasn't already set this variable
|
resp := []GpuInfo{}
|
||||||
func AMDGetGPUInfo(resp *GpuInfo) {
|
|
||||||
// TODO - DRY this out with windows
|
|
||||||
if !AMDDetected() {
|
if !AMDDetected() {
|
||||||
return
|
return resp
|
||||||
}
|
}
|
||||||
skip := map[int]interface{}{}
|
|
||||||
|
|
||||||
// Opportunistic logging of driver version to aid in troubleshooting
|
// Opportunistic logging of driver version to aid in troubleshooting
|
||||||
ver, err := AMDDriverVersion()
|
ver, err := AMDDriverVersion()
|
||||||
|
@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||||
slog.Info("AMD Driver: " + ver)
|
slog.Info("AMD Driver: " + ver)
|
||||||
} else {
|
} else {
|
||||||
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
||||||
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
|
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the user has specified exactly which GPUs to use, look up their memory
|
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
|
||||||
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
|
var visibleDevices []string
|
||||||
if visibleDevices != "" {
|
hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
|
||||||
ids := []int{}
|
rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
|
||||||
for _, idStr := range strings.Split(visibleDevices, ",") {
|
gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
|
||||||
id, err := strconv.Atoi(idStr)
|
switch {
|
||||||
if err != nil {
|
// TODO is this priority order right?
|
||||||
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
|
case hipVD != "":
|
||||||
} else {
|
visibleDevices = strings.Split(hipVD, ",")
|
||||||
ids = append(ids, id)
|
case rocrVD != "":
|
||||||
|
visibleDevices = strings.Split(rocrVD, ",")
|
||||||
|
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
|
||||||
|
// all our test systems show GPU-XX indicating UUID is not supported
|
||||||
|
case gpuDO != "":
|
||||||
|
visibleDevices = strings.Split(gpuDO, ",")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
amdProcMemLookup(resp, nil, ids)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gather GFX version information from all detected cards
|
|
||||||
gfx := AMDGFXVersions()
|
|
||||||
verStrings := []string{}
|
|
||||||
for i, v := range gfx {
|
|
||||||
verStrings = append(verStrings, v.ToGFXString())
|
|
||||||
if v.Major == 0 {
|
|
||||||
// Silently skip CPUs
|
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if v.Major < 9 {
|
|
||||||
// TODO consider this a build-time setting if we can support 8xx family GPUs
|
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
|
|
||||||
skip[i] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
|
|
||||||
|
|
||||||
// Abort if all GPUs are skipped
|
|
||||||
if len(skip) >= len(gfx) {
|
|
||||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
|
|
||||||
libDir, err := AMDValidateLibDir()
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
updateLibPath(libDir)
|
|
||||||
|
|
||||||
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
||||||
if gfxOverride == "" {
|
var supported []string
|
||||||
supported, err := GetSupportedGFX(libDir)
|
libDir := ""
|
||||||
|
|
||||||
|
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
|
||||||
|
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
||||||
|
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
||||||
|
cpuCount := 0
|
||||||
|
for _, match := range matches {
|
||||||
|
slog.Debug("evaluating amdgpu node " + match)
|
||||||
|
fp, err := os.Open(match)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
|
slog.Debug("failed to open sysfs node", "file", match, "error", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
|
|
||||||
|
|
||||||
for i, v := range gfx {
|
|
||||||
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
|
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
|
|
||||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
|
||||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
|
|
||||||
skip[i] = struct{}{}
|
|
||||||
} else {
|
|
||||||
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(skip) >= len(gfx) {
|
|
||||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ids := make([]int, len(gfx))
|
|
||||||
i := 0
|
|
||||||
for k := range gfx {
|
|
||||||
ids[i] = k
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
amdProcMemLookup(resp, skip, ids)
|
|
||||||
if resp.memInfo.DeviceCount == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(skip) > 0 {
|
|
||||||
amdSetVisibleDevices(ids, skip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateLibPath(libDir string) {
|
|
||||||
ldPaths := []string{}
|
|
||||||
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
|
||||||
ldPaths = strings.Split(val, ":")
|
|
||||||
}
|
|
||||||
for _, d := range ldPaths {
|
|
||||||
if d == libDir {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val := strings.Join(append(ldPaths, libDir), ":")
|
|
||||||
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
|
|
||||||
os.Setenv("LD_LIBRARY_PATH", val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Walk the sysfs nodes for the available GPUs and gather information from them
|
|
||||||
// skipping over any devices in the skip map
|
|
||||||
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
|
||||||
resp.memInfo.DeviceCount = 0
|
|
||||||
resp.memInfo.TotalMemory = 0
|
|
||||||
resp.memInfo.FreeMemory = 0
|
|
||||||
slog.Debug("discovering VRAM for amdgpu devices")
|
|
||||||
if len(ids) == 0 {
|
|
||||||
entries, err := os.ReadDir(AMDNodesSysfsDir)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, node := range entries {
|
|
||||||
if !node.IsDir() {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
id, err := strconv.Atoi(node.Name())
|
defer fp.Close()
|
||||||
|
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
|
slog.Debug("failed to parse node ID", "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ids = append(ids, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
|
|
||||||
|
|
||||||
for _, id := range ids {
|
scanner := bufio.NewScanner(fp)
|
||||||
if _, skipped := skip[id]; skipped {
|
isCPU := false
|
||||||
|
var major, minor, patch uint64
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
|
||||||
|
if strings.HasPrefix(line, "gfx_target_version") {
|
||||||
|
ver := strings.Fields(line)
|
||||||
|
|
||||||
|
// Detect CPUs
|
||||||
|
if len(ver) == 2 && ver[1] == "0" {
|
||||||
|
slog.Debug("detected CPU " + match)
|
||||||
|
isCPU = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ver) != 2 || len(ver[1]) < 5 {
|
||||||
|
slog.Warn("malformed "+match, "gfx_target_version", line)
|
||||||
|
// If this winds up being a CPU, our offsets may be wrong
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
l := len(ver[1])
|
||||||
|
var err1, err2, err3 error
|
||||||
|
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
||||||
|
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
||||||
|
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
|
||||||
|
if err1 != nil || err2 != nil || err3 != nil {
|
||||||
|
slog.Debug("malformed int " + line)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - any other properties we want to extract and record?
|
||||||
|
// vendor_id + device_id -> pci lookup for "Name"
|
||||||
|
// Other metrics that may help us understand relative performance between multiple GPUs
|
||||||
|
}
|
||||||
|
|
||||||
|
if isCPU {
|
||||||
|
cpuCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// CPUs are always first in the list
|
||||||
|
gpuID := nodeID - cpuCount
|
||||||
|
|
||||||
|
// Shouldn't happen, but just in case...
|
||||||
|
if gpuID < 0 {
|
||||||
|
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
|
||||||
|
return []GpuInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if int(major) < RocmComputeMin {
|
||||||
|
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the memory for the current node
|
||||||
totalMemory := uint64(0)
|
totalMemory := uint64(0)
|
||||||
usedMemory := uint64(0)
|
usedMemory := uint64(0)
|
||||||
// Adjust for sysfs vs HIP ids
|
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
|
||||||
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
|
|
||||||
propFiles, err := filepath.Glob(propGlob)
|
propFiles, err := filepath.Glob(propGlob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
|
slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
|
||||||
}
|
}
|
||||||
// 1 or more memory banks - sum the values of all of them
|
// 1 or more memory banks - sum the values of all of them
|
||||||
for _, propFile := range propFiles {
|
for _, propFile := range propFiles {
|
||||||
fp, err := os.Open(propFile)
|
fp, err := os.Open(propFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
|
slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer fp.Close()
|
defer fp.Close()
|
||||||
|
@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if totalMemory == 0 {
|
if totalMemory == 0 {
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
|
slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
|
||||||
skip[id] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if totalMemory < IGPUMemLimit {
|
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
|
||||||
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
|
|
||||||
skip[id] = struct{}{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
|
|
||||||
usedFiles, err := filepath.Glob(usedGlob)
|
usedFiles, err := filepath.Glob(usedGlob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
|
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, usedFile := range usedFiles {
|
for _, usedFile := range usedFiles {
|
||||||
fp, err := os.Open(usedFile)
|
fp, err := os.Open(usedFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
|
slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer fp.Close()
|
defer fp.Close()
|
||||||
data, err := io.ReadAll(fp)
|
data, err := io.ReadAll(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
|
slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
|
slog.Warn("malformed used memory", "data", string(data), "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
usedMemory += used
|
usedMemory += used
|
||||||
}
|
}
|
||||||
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
|
|
||||||
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
|
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||||
resp.memInfo.DeviceCount++
|
if totalMemory < IGPUMemLimit {
|
||||||
resp.memInfo.TotalMemory += totalMemory
|
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||||
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
|
continue
|
||||||
}
|
}
|
||||||
if resp.memInfo.DeviceCount > 0 {
|
|
||||||
resp.Library = "rocm"
|
slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||||
|
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||||
|
gpuInfo := GpuInfo{
|
||||||
|
Library: "rocm",
|
||||||
|
memInfo: memInfo{
|
||||||
|
TotalMemory: totalMemory,
|
||||||
|
FreeMemory: (totalMemory - usedMemory),
|
||||||
|
},
|
||||||
|
ID: fmt.Sprintf("%d", gpuID),
|
||||||
|
// Name: not exposed in sysfs directly, would require pci device id lookup
|
||||||
|
Major: int(major),
|
||||||
|
Minor: int(minor),
|
||||||
|
Patch: int(patch),
|
||||||
|
MinimumMemory: rocmMinimumMemory,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||||
|
if len(visibleDevices) > 0 {
|
||||||
|
include := false
|
||||||
|
for _, visible := range visibleDevices {
|
||||||
|
if visible == gpuInfo.ID {
|
||||||
|
include = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !include {
|
||||||
|
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final validation is gfx compatibility - load the library if we haven't already loaded it
|
||||||
|
// even if the user overrides, we still need to validate the library
|
||||||
|
if libDir == "" {
|
||||||
|
libDir, err = AMDValidateLibDir()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
|
||||||
|
return []GpuInfo{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gpuInfo.DependencyPath = libDir
|
||||||
|
|
||||||
|
if gfxOverride == "" {
|
||||||
|
// Only load supported list once
|
||||||
|
if len(supported) == 0 {
|
||||||
|
supported, err = GetSupportedGFX(libDir)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
|
||||||
|
return []GpuInfo{}
|
||||||
|
}
|
||||||
|
slog.Debug("rocm supported GPUs", "types", supported)
|
||||||
|
}
|
||||||
|
gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
|
||||||
|
if !slices.Contains[[]string, string](supported, gfx) {
|
||||||
|
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||||
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||||
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The GPU has passed all the verification steps and is supported
|
||||||
|
resp = append(resp, gpuInfo)
|
||||||
|
}
|
||||||
|
if len(resp) == 0 {
|
||||||
|
slog.Info("no compatible amdgpu devices detected")
|
||||||
|
}
|
||||||
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
// Quick check for AMD driver so we can skip amdgpu discovery if not present
|
// Quick check for AMD driver so we can skip amdgpu discovery if not present
|
||||||
|
@ -280,87 +297,24 @@ func AMDDetected() bool {
|
||||||
slog.Debug("amdgpu driver not detected " + sysfsDir)
|
slog.Debug("amdgpu driver not detected " + sysfsDir)
|
||||||
return false
|
return false
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
|
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupLink(source, target string) error {
|
|
||||||
if err := os.RemoveAll(target); err != nil {
|
|
||||||
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
|
|
||||||
}
|
|
||||||
if err := os.Symlink(source, target); err != nil {
|
|
||||||
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
|
|
||||||
}
|
|
||||||
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the AMD rocm lib dir is wired up
|
|
||||||
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
|
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
|
||||||
// failing that, tell the user how to download it on their own
|
// failing that, tell the user how to download it on their own
|
||||||
func AMDValidateLibDir() (string, error) {
|
func AMDValidateLibDir() (string, error) {
|
||||||
// We rely on the rpath compiled into our library to find rocm
|
libDir, err := commonAMDValidateLibDir()
|
||||||
// so we establish a symlink to wherever we find it on the system
|
|
||||||
// to <payloads>/rocm
|
|
||||||
payloadsDir, err := PayloadsDir()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we already have a rocm dependency wired, nothing more to do
|
|
||||||
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
|
|
||||||
if rocmLibUsable(rocmTargetDir) {
|
|
||||||
return rocmTargetDir, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// next to the running binary
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
peerDir := filepath.Dir(exe)
|
return libDir, nil
|
||||||
if rocmLibUsable(peerDir) {
|
|
||||||
slog.Debug("detected ROCM next to ollama executable " + peerDir)
|
|
||||||
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
|
|
||||||
}
|
|
||||||
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
|
|
||||||
if rocmLibUsable(peerDir) {
|
|
||||||
slog.Debug("detected ROCM next to ollama executable " + peerDir)
|
|
||||||
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Well known ollama installer path
|
// Well known ollama installer path
|
||||||
installedRocmDir := "/usr/share/ollama/lib/rocm"
|
installedRocmDir := "/usr/share/ollama/lib/rocm"
|
||||||
if rocmLibUsable(installedRocmDir) {
|
if rocmLibUsable(installedRocmDir) {
|
||||||
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
|
return installedRocmDir, nil
|
||||||
}
|
|
||||||
|
|
||||||
// Prefer explicit HIP env var
|
|
||||||
hipPath := os.Getenv("HIP_PATH")
|
|
||||||
if hipPath != "" {
|
|
||||||
hipLibDir := filepath.Join(hipPath, "lib")
|
|
||||||
if rocmLibUsable(hipLibDir) {
|
|
||||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
|
||||||
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan the library path for potential matches
|
|
||||||
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
|
|
||||||
for _, ldPath := range ldPaths {
|
|
||||||
d, err := filepath.Abs(ldPath)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if rocmLibUsable(d) {
|
|
||||||
return rocmTargetDir, setupLink(d, rocmTargetDir)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Well known location(s)
|
|
||||||
if rocmLibUsable("/opt/rocm/lib") {
|
|
||||||
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we still haven't found a usable rocm, the user will have to install it on their own
|
// If we still haven't found a usable rocm, the user will have to install it on their own
|
||||||
|
@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(string(verString)), nil
|
return strings.TrimSpace(string(verString)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func AMDGFXVersions() map[int]Version {
|
|
||||||
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
|
|
||||||
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
|
|
||||||
res := map[int]Version{}
|
|
||||||
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
|
|
||||||
for _, match := range matches {
|
|
||||||
fp, err := os.Open(match)
|
|
||||||
if err != nil {
|
|
||||||
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
defer fp.Close()
|
|
||||||
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
|
|
||||||
if err != nil {
|
|
||||||
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if i == 0 {
|
|
||||||
// Skipping the CPU
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Align with HIP IDs (zero is first GPU, not CPU)
|
|
||||||
i -= 1
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(fp)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := strings.TrimSpace(scanner.Text())
|
|
||||||
if strings.HasPrefix(line, "gfx_target_version") {
|
|
||||||
ver := strings.Fields(line)
|
|
||||||
if len(ver) != 2 || len(ver[1]) < 5 {
|
|
||||||
if ver[1] != "0" {
|
|
||||||
slog.Debug("malformed " + line)
|
|
||||||
}
|
|
||||||
res[i] = Version{
|
|
||||||
Major: 0,
|
|
||||||
Minor: 0,
|
|
||||||
Patch: 0,
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
l := len(ver[1])
|
|
||||||
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
|
|
||||||
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
|
|
||||||
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
|
|
||||||
if err1 != nil || err2 != nil || err3 != nil {
|
|
||||||
slog.Debug("malformed int " + line)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
res[i] = Version{
|
|
||||||
Major: uint(major),
|
|
||||||
Minor: uint(minor),
|
|
||||||
Patch: uint(patch),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v Version) ToGFXString() string {
|
|
||||||
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,11 +7,13 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob?
|
|
||||||
|
|
||||||
// TODO We're looking for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
|
// TODO We're looking for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
|
||||||
iGPUName = "AMD Radeon(TM) Graphics"
|
iGPUName = "AMD Radeon(TM) Graphics"
|
||||||
|
@ -20,38 +22,35 @@ const (
|
||||||
var (
|
var (
|
||||||
// Used to validate if the given ROCm lib is usable
|
// Used to validate if the given ROCm lib is usable
|
||||||
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
|
||||||
|
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
|
||||||
)
|
)
|
||||||
|
|
||||||
func AMDGetGPUInfo(resp *GpuInfo) {
|
func AMDGetGPUInfo() []GpuInfo {
|
||||||
|
resp := []GpuInfo{}
|
||||||
hl, err := NewHipLib()
|
hl, err := NewHipLib()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug(err.Error())
|
slog.Debug(err.Error())
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
defer hl.Release()
|
defer hl.Release()
|
||||||
skip := map[int]interface{}{}
|
|
||||||
ids := []int{}
|
|
||||||
resp.memInfo.DeviceCount = 0
|
|
||||||
resp.memInfo.TotalMemory = 0
|
|
||||||
resp.memInfo.FreeMemory = 0
|
|
||||||
|
|
||||||
ver, err := hl.AMDDriverVersion()
|
ver, err := hl.AMDDriverVersion()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
slog.Info("AMD Driver: " + ver)
|
slog.Info("AMD Driver: " + ver)
|
||||||
} else {
|
} else {
|
||||||
// For now this is benign, but we may eventually need to fail compatibility checks
|
// For now this is benign, but we may eventually need to fail compatibility checks
|
||||||
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
|
slog.Debug("error looking up amd driver version", "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
|
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
|
||||||
count := hl.HipGetDeviceCount()
|
count := hl.HipGetDeviceCount()
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
libDir, err := AMDValidateLibDir()
|
libDir, err := AMDValidateLibDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
|
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var supported []string
|
var supported []string
|
||||||
|
@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
supported, err = GetSupportedGFX(libDir)
|
supported, err = GetSupportedGFX(libDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
|
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info(fmt.Sprintf("detected %d hip devices", count))
|
slog.Info("detected hip devices", "count", count)
|
||||||
|
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
ids = append(ids, i)
|
|
||||||
err = hl.HipSetDevice(i)
|
err = hl.HipSetDevice(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
slog.Warn("set device", "id", i, "error", err)
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
props, err := hl.HipGetDeviceProperties(i)
|
props, err := hl.HipGetDeviceProperties(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
slog.Warn("get properties", "id", i, "error", err)
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
n := bytes.IndexByte(props.Name[:], 0)
|
n := bytes.IndexByte(props.Name[:], 0)
|
||||||
name := string(props.Name[:n])
|
name := string(props.Name[:n])
|
||||||
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
|
// TODO is UUID actually populated on windows?
|
||||||
|
// Can luid be used on windows for setting visible devices (and is it actually set?)
|
||||||
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
n = bytes.IndexByte(props.GcnArchName[:], 0)
|
||||||
gfx := string(props.GcnArchName[:n])
|
gfx := string(props.GcnArchName[:n])
|
||||||
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
|
slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
|
||||||
|
var major, minor, patch string
|
||||||
|
switch len(gfx) {
|
||||||
|
case 6:
|
||||||
|
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
|
||||||
|
case 7:
|
||||||
|
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
|
||||||
|
}
|
||||||
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
|
||||||
// TODO Why isn't props.iGPU accurate!?
|
// TODO Why isn't props.iGPU accurate!?
|
||||||
if strings.EqualFold(name, iGPUName) {
|
if strings.EqualFold(name, iGPUName) {
|
||||||
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
|
slog.Info("iGPU detected skipping", "id", i)
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if gfxOverride == "" {
|
if gfxOverride == "" {
|
||||||
if !slices.Contains[[]string, string](supported, gfx) {
|
if !slices.Contains[[]string, string](supported, gfx) {
|
||||||
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
|
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
|
||||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||||
skip[i] = struct{}{}
|
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
|
slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
totalMemory, freeMemory, err := hl.HipMemGetInfo()
|
freeMemory, totalMemory, err := hl.HipMemGetInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
|
slog.Warn("get mem info", "id", i, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO according to docs, freeMem may lie on windows!
|
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||||
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
|
if totalMemory < IGPUMemLimit {
|
||||||
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory))
|
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||||
resp.memInfo.DeviceCount++
|
continue
|
||||||
resp.memInfo.TotalMemory += totalMemory
|
|
||||||
resp.memInfo.FreeMemory += freeMemory
|
|
||||||
}
|
}
|
||||||
if resp.memInfo.DeviceCount > 0 {
|
|
||||||
resp.Library = "rocm"
|
// TODO revisit this once ROCm v6 is available on windows.
|
||||||
|
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
|
||||||
|
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||||
|
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
|
||||||
|
gpuInfo := GpuInfo{
|
||||||
|
Library: "rocm",
|
||||||
|
memInfo: memInfo{
|
||||||
|
TotalMemory: totalMemory,
|
||||||
|
FreeMemory: freeMemory,
|
||||||
|
},
|
||||||
|
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
|
||||||
|
DependencyPath: libDir,
|
||||||
|
MinimumMemory: rocmMinimumMemory,
|
||||||
}
|
}
|
||||||
// Abort if all GPUs are skipped
|
if major != "" {
|
||||||
if len(skip) >= count {
|
gpuInfo.Major, err = strconv.Atoi(major)
|
||||||
slog.Info("all detected amdgpus are skipped, falling back to CPU")
|
if err != nil {
|
||||||
return
|
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||||
}
|
}
|
||||||
if len(skip) > 0 {
|
|
||||||
amdSetVisibleDevices(ids, skip)
|
|
||||||
}
|
}
|
||||||
UpdatePath(libDir)
|
if minor != "" {
|
||||||
|
gpuInfo.Minor, err = strconv.Atoi(minor)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if patch != "" {
|
||||||
|
// Patch rev is hex; e.g. gfx90a
|
||||||
|
p, err := strconv.ParseInt(patch, 16, 0)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info("failed to parse version", "version", gfx, "error", err)
|
||||||
|
} else {
|
||||||
|
gpuInfo.Patch = int(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if gpuInfo.Major < RocmComputeMin {
|
||||||
|
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = append(resp, gpuInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func AMDValidateLibDir() (string, error) {
|
func AMDValidateLibDir() (string, error) {
|
||||||
// On windows non-admins typically can't create links
|
libDir, err := commonAMDValidateLibDir()
|
||||||
// so instead of trying to rely on rpath and a link in
|
|
||||||
// $LibDir/rocm, we instead rely on setting PATH to point
|
|
||||||
// to the location of the ROCm library
|
|
||||||
|
|
||||||
// Installer payload location if we're running the installed binary
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
|
return libDir, nil
|
||||||
if rocmLibUsable(rocmTargetDir) {
|
|
||||||
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
|
|
||||||
return rocmTargetDir, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Installer payload (if we're running from some other location)
|
// Installer payload (if we're running from some other location)
|
||||||
|
@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) {
|
||||||
return rocmTargetDir, nil
|
return rocmTargetDir, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefer explicit HIP env var
|
|
||||||
hipPath := os.Getenv("HIP_PATH")
|
|
||||||
if hipPath != "" {
|
|
||||||
hipLibDir := filepath.Join(hipPath, "bin")
|
|
||||||
if rocmLibUsable(hipLibDir) {
|
|
||||||
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
|
|
||||||
return hipLibDir, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Well known location(s)
|
|
||||||
if rocmLibUsable(RocmStandardLocation) {
|
|
||||||
return RocmStandardLocation, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
|
||||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
||||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||||
|
|
|
@ -12,6 +12,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/server/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -24,8 +26,16 @@ func PayloadsDir() (string, error) {
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
var err error
|
var err error
|
||||||
if payloadsDir == "" {
|
if payloadsDir == "" {
|
||||||
|
runnersDir := envconfig.RunnersDir
|
||||||
|
|
||||||
|
if runnersDir != "" {
|
||||||
|
payloadsDir = runnersDir
|
||||||
|
return payloadsDir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The remainder only applies on non-windows where we still carry payloads in the main executable
|
||||||
cleanupTmpDirs()
|
cleanupTmpDirs()
|
||||||
tmpDir := os.Getenv("OLLAMA_TMPDIR")
|
tmpDir := envconfig.TmpDir
|
||||||
if tmpDir == "" {
|
if tmpDir == "" {
|
||||||
tmpDir, err = os.MkdirTemp("", "ollama")
|
tmpDir, err = os.MkdirTemp("", "ollama")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -80,7 +90,7 @@ func cleanupTmpDirs() {
|
||||||
}
|
}
|
||||||
err = os.RemoveAll(d)
|
err = os.RemoveAll(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
|
slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,7 +98,8 @@ func cleanupTmpDirs() {
|
||||||
func Cleanup() {
|
func Cleanup() {
|
||||||
lock.Lock()
|
lock.Lock()
|
||||||
defer lock.Unlock()
|
defer lock.Unlock()
|
||||||
if payloadsDir != "" {
|
runnersDir := envconfig.RunnersDir
|
||||||
|
if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
|
||||||
// We want to fully clean up the tmpdir parent of the payloads dir
|
// We want to fully clean up the tmpdir parent of the payloads dir
|
||||||
tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
|
tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
|
||||||
slog.Debug("cleaning up", "dir", tmpDir)
|
slog.Debug("cleaning up", "dir", tmpDir)
|
||||||
|
@ -120,7 +131,7 @@ func UpdatePath(dir string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
|
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
|
||||||
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
|
slog.Info("updating", "PATH", newPath)
|
||||||
os.Setenv("PATH", newPath)
|
os.Setenv("PATH", newPath)
|
||||||
}
|
}
|
||||||
// linux and darwin rely on rpath
|
// linux and darwin rely on rpath
|
||||||
|
|
22
gpu/cuda_common.go
Normal file
22
gpu/cuda_common.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
//go:build linux || windows
|
||||||
|
|
||||||
|
package gpu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||||
|
ids := []string{}
|
||||||
|
for _, info := range gpuInfo {
|
||||||
|
if info.Library != "cuda" {
|
||||||
|
// TODO shouldn't happen if things are wired correctly...
|
||||||
|
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ids = append(ids, info.ID)
|
||||||
|
}
|
||||||
|
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||||
|
|
||||||
|
}
|
324
gpu/gpu.go
324
gpu/gpu.go
|
@ -16,22 +16,23 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/server/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
type handles struct {
|
type handles struct {
|
||||||
nvml *C.nvml_handle_t
|
deviceCount int
|
||||||
cudart *C.cudart_handle_t
|
cudart *C.cudart_handle_t
|
||||||
|
nvcuda *C.nvcuda_handle_t
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
cudaMinimumMemory = 457 * format.MebiByte
|
cudaMinimumMemory = 256 * format.MebiByte
|
||||||
rocmMinimumMemory = 457 * format.MebiByte
|
rocmMinimumMemory = 256 * format.MebiByte
|
||||||
)
|
)
|
||||||
|
|
||||||
var gpuMutex sync.Mutex
|
var gpuMutex sync.Mutex
|
||||||
|
@ -39,26 +40,10 @@ var gpuMutex sync.Mutex
|
||||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||||
var CudaComputeMin = [2]C.int{5, 0}
|
var CudaComputeMin = [2]C.int{5, 0}
|
||||||
|
|
||||||
// Possible locations for the nvidia-ml library
|
var RocmComputeMin = 9
|
||||||
var NvmlLinuxGlobs = []string{
|
|
||||||
"/usr/local/cuda/lib64/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/wsl/lib/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
|
|
||||||
"/opt/cuda/lib64/libnvidia-ml.so*",
|
|
||||||
"/usr/lib*/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
|
|
||||||
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
|
|
||||||
"/usr/local/lib*/libnvidia-ml.so*",
|
|
||||||
|
|
||||||
// TODO: are these stubs ever valid?
|
// TODO find a better way to detect iGPU instead of minimum memory
|
||||||
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
|
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
|
||||||
}
|
|
||||||
|
|
||||||
var NvmlWindowsGlobs = []string{
|
|
||||||
"c:\\Windows\\System32\\nvml.dll",
|
|
||||||
}
|
|
||||||
|
|
||||||
var CudartLinuxGlobs = []string{
|
var CudartLinuxGlobs = []string{
|
||||||
"/usr/local/cuda/lib64/libcudart.so*",
|
"/usr/local/cuda/lib64/libcudart.so*",
|
||||||
|
@ -79,6 +64,22 @@ var CudartWindowsGlobs = []string{
|
||||||
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
|
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var NvcudaLinuxGlobs = []string{
|
||||||
|
"/usr/local/cuda*/targets/*/lib/libcuda.so*",
|
||||||
|
"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
|
||||||
|
"/usr/lib/*-linux-gnu/libcuda.so*",
|
||||||
|
"/usr/lib/wsl/lib/libcuda.so*",
|
||||||
|
"/usr/lib/wsl/drivers/*/libcuda.so*",
|
||||||
|
"/opt/cuda/lib*/libcuda.so*",
|
||||||
|
"/usr/local/cuda/lib*/libcuda.so*",
|
||||||
|
"/usr/lib*/libcuda.so*",
|
||||||
|
"/usr/local/lib*/libcuda.so*",
|
||||||
|
}
|
||||||
|
|
||||||
|
var NvcudaWindowsGlobs = []string{
|
||||||
|
"c:\\windows\\system*\\nvcuda.dll",
|
||||||
|
}
|
||||||
|
|
||||||
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
|
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
|
||||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||||
|
@ -88,61 +89,62 @@ func initGPUHandles() *handles {
|
||||||
|
|
||||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||||
|
|
||||||
gpuHandles := &handles{nil, nil}
|
gpuHandles := &handles{}
|
||||||
var nvmlMgmtName string
|
|
||||||
var nvmlMgmtPatterns []string
|
|
||||||
var cudartMgmtName string
|
var cudartMgmtName string
|
||||||
var cudartMgmtPatterns []string
|
var cudartMgmtPatterns []string
|
||||||
|
var nvcudaMgmtName string
|
||||||
|
var nvcudaMgmtPatterns []string
|
||||||
|
|
||||||
tmpDir, _ := PayloadsDir()
|
tmpDir, _ := PayloadsDir()
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
nvmlMgmtName = "nvml.dll"
|
|
||||||
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
|
|
||||||
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
|
|
||||||
cudartMgmtName = "cudart64_*.dll"
|
cudartMgmtName = "cudart64_*.dll"
|
||||||
localAppData := os.Getenv("LOCALAPPDATA")
|
localAppData := os.Getenv("LOCALAPPDATA")
|
||||||
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
|
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
|
||||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
|
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
|
||||||
|
// Aligned with driver, we can't carry as payloads
|
||||||
|
nvcudaMgmtName = "nvcuda.dll"
|
||||||
|
nvcudaMgmtPatterns = NvcudaWindowsGlobs
|
||||||
case "linux":
|
case "linux":
|
||||||
nvmlMgmtName = "libnvidia-ml.so"
|
|
||||||
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
|
|
||||||
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
|
|
||||||
cudartMgmtName = "libcudart.so*"
|
cudartMgmtName = "libcudart.so*"
|
||||||
if tmpDir != "" {
|
if tmpDir != "" {
|
||||||
// TODO - add "payloads" for subprocess
|
// TODO - add "payloads" for subprocess
|
||||||
cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
|
cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
|
||||||
}
|
}
|
||||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
|
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
|
||||||
|
// Aligned with driver, we can't carry as payloads
|
||||||
|
nvcudaMgmtName = "libcuda.so*"
|
||||||
|
nvcudaMgmtPatterns = NvcudaLinuxGlobs
|
||||||
default:
|
default:
|
||||||
return gpuHandles
|
return gpuHandles
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Detecting GPU type")
|
slog.Info("Detecting GPUs")
|
||||||
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
|
nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
|
||||||
if len(cudartLibPaths) > 0 {
|
if len(nvcudaLibPaths) > 0 {
|
||||||
cudart := LoadCUDARTMgmt(cudartLibPaths)
|
deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
|
||||||
if cudart != nil {
|
if nvcuda != nil {
|
||||||
slog.Info("Nvidia GPU detected via cudart")
|
slog.Info("detected GPUs", "count", deviceCount, "library", libPath)
|
||||||
gpuHandles.cudart = cudart
|
gpuHandles.nvcuda = nvcuda
|
||||||
|
gpuHandles.deviceCount = deviceCount
|
||||||
return gpuHandles
|
return gpuHandles
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
|
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
|
||||||
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
|
if len(cudartLibPaths) > 0 {
|
||||||
if len(nvmlLibPaths) > 0 {
|
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
|
||||||
nvml := LoadNVMLMgmt(nvmlLibPaths)
|
if cudart != nil {
|
||||||
if nvml != nil {
|
slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
|
||||||
slog.Info("Nvidia GPU detected via nvidia-ml")
|
gpuHandles.cudart = cudart
|
||||||
gpuHandles.nvml = nvml
|
gpuHandles.deviceCount = deviceCount
|
||||||
return gpuHandles
|
return gpuHandles
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return gpuHandles
|
return gpuHandles
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGPUInfo() GpuInfo {
|
func GetGPUInfo() GpuInfoList {
|
||||||
// TODO - consider exploring lspci (and equivalent on windows) to check for
|
// TODO - consider exploring lspci (and equivalent on windows) to check for
|
||||||
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
||||||
gpuMutex.Lock()
|
gpuMutex.Lock()
|
||||||
|
@ -150,12 +152,12 @@ func GetGPUInfo() GpuInfo {
|
||||||
|
|
||||||
gpuHandles := initGPUHandles()
|
gpuHandles := initGPUHandles()
|
||||||
defer func() {
|
defer func() {
|
||||||
if gpuHandles.nvml != nil {
|
|
||||||
C.nvml_release(*gpuHandles.nvml)
|
|
||||||
}
|
|
||||||
if gpuHandles.cudart != nil {
|
if gpuHandles.cudart != nil {
|
||||||
C.cudart_release(*gpuHandles.cudart)
|
C.cudart_release(*gpuHandles.cudart)
|
||||||
}
|
}
|
||||||
|
if gpuHandles.nvcuda != nil {
|
||||||
|
C.nvcuda_release(*gpuHandles.nvcuda)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
||||||
|
@ -164,73 +166,75 @@ func GetGPUInfo() GpuInfo {
|
||||||
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
|
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var memInfo C.mem_info_t
|
// On windows we bundle the nvidia library one level above the runner dir
|
||||||
resp := GpuInfo{}
|
depPath := ""
|
||||||
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
|
||||||
C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
|
depPath = filepath.Dir(envconfig.RunnersDir)
|
||||||
if memInfo.err != nil {
|
}
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
|
|
||||||
C.free(unsafe.Pointer(memInfo.err))
|
var memInfo C.mem_info_t
|
||||||
} else if memInfo.count > 0 {
|
resp := []GpuInfo{}
|
||||||
// Verify minimum compute capability
|
|
||||||
var cc C.nvml_compute_capability_t
|
// NVIDIA first
|
||||||
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
|
for i := 0; i < gpuHandles.deviceCount; i++ {
|
||||||
if cc.err != nil {
|
// TODO once we support CPU compilation variants of GPU libraries refine this...
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
|
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||||
C.free(unsafe.Pointer(cc.err))
|
continue
|
||||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
}
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
gpuInfo := GpuInfo{
|
||||||
resp.Library = "cuda"
|
Library: "cuda",
|
||||||
resp.MinimumMemory = cudaMinimumMemory
|
}
|
||||||
} else {
|
if gpuHandles.cudart != nil {
|
||||||
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
|
||||||
}
|
} else {
|
||||||
}
|
C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
|
||||||
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
}
|
||||||
C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
|
if memInfo.err != nil {
|
||||||
if memInfo.err != nil {
|
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||||
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
|
C.free(unsafe.Pointer(memInfo.err))
|
||||||
C.free(unsafe.Pointer(memInfo.err))
|
continue
|
||||||
} else if memInfo.count > 0 {
|
}
|
||||||
// Verify minimum compute capability
|
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
|
||||||
var cc C.cudart_compute_capability_t
|
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
|
||||||
C.cudart_compute_capability(*gpuHandles.cudart, &cc)
|
continue
|
||||||
if cc.err != nil {
|
}
|
||||||
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
|
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||||
C.free(unsafe.Pointer(cc.err))
|
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||||
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
gpuInfo.Major = int(memInfo.major)
|
||||||
resp.Library = "cuda"
|
gpuInfo.Minor = int(memInfo.minor)
|
||||||
resp.MinimumMemory = cudaMinimumMemory
|
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||||
} else {
|
gpuInfo.DependencyPath = depPath
|
||||||
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
|
||||||
}
|
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||||
}
|
resp = append(resp, gpuInfo)
|
||||||
} else {
|
}
|
||||||
AMDGetGPUInfo(&resp)
|
|
||||||
if resp.Library != "" {
|
// Then AMD
|
||||||
resp.MinimumMemory = rocmMinimumMemory
|
resp = append(resp, AMDGetGPUInfo()...)
|
||||||
return resp
|
|
||||||
}
|
if len(resp) == 0 {
|
||||||
}
|
C.cpu_check_ram(&memInfo)
|
||||||
if resp.Library == "" {
|
if memInfo.err != nil {
|
||||||
C.cpu_check_ram(&memInfo)
|
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
|
||||||
resp.Library = "cpu"
|
C.free(unsafe.Pointer(memInfo.err))
|
||||||
resp.Variant = cpuVariant
|
return resp
|
||||||
}
|
}
|
||||||
if memInfo.err != nil {
|
gpuInfo := GpuInfo{
|
||||||
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
|
Library: "cpu",
|
||||||
C.free(unsafe.Pointer(memInfo.err))
|
Variant: cpuVariant,
|
||||||
return resp
|
}
|
||||||
|
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||||
|
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||||
|
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||||
|
|
||||||
|
resp = append(resp, gpuInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.DeviceCount = uint32(memInfo.count)
|
|
||||||
resp.FreeMemory = uint64(memInfo.free)
|
|
||||||
resp.TotalMemory = uint64(memInfo.total)
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCPUMem() (memInfo, error) {
|
func GetCPUMem() (memInfo, error) {
|
||||||
var ret memInfo
|
var ret memInfo
|
||||||
var info C.mem_info_t
|
var info C.mem_info_t
|
||||||
C.cpu_check_ram(&info)
|
C.cpu_check_ram(&info)
|
||||||
|
@ -243,29 +247,12 @@ func getCPUMem() (memInfo, error) {
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckVRAM() (uint64, error) {
|
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
|
||||||
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
|
|
||||||
if userLimit != "" {
|
|
||||||
avail, err := strconv.ParseInt(userLimit, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
|
|
||||||
}
|
|
||||||
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
|
|
||||||
return uint64(avail), nil
|
|
||||||
}
|
|
||||||
gpuInfo := GetGPUInfo()
|
|
||||||
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
|
|
||||||
return gpuInfo.FreeMemory, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
|
|
||||||
}
|
|
||||||
|
|
||||||
func FindGPULibs(baseLibName string, patterns []string) []string {
|
|
||||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||||
var ldPaths []string
|
var ldPaths []string
|
||||||
|
var patterns []string
|
||||||
gpuLibPaths := []string{}
|
gpuLibPaths := []string{}
|
||||||
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
|
slog.Debug("Searching for GPU library", "name", baseLibName)
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
|
@ -283,8 +270,14 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||||
}
|
}
|
||||||
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
|
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
|
||||||
}
|
}
|
||||||
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
|
patterns = append(patterns, defaultPatterns...)
|
||||||
|
slog.Debug("gpu library search", "globs", patterns)
|
||||||
for _, pattern := range patterns {
|
for _, pattern := range patterns {
|
||||||
|
|
||||||
|
// Nvidia PhysX known to return bogus results
|
||||||
|
if strings.Contains(pattern, "PhysX") {
|
||||||
|
slog.Debug("skipping PhysX cuda library path", "path", pattern)
|
||||||
|
}
|
||||||
// Ignore glob discovery errors
|
// Ignore glob discovery errors
|
||||||
matches, _ := filepath.Glob(pattern)
|
matches, _ := filepath.Glob(pattern)
|
||||||
for _, match := range matches {
|
for _, match := range matches {
|
||||||
|
@ -311,28 +304,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
|
slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
|
||||||
return gpuLibPaths
|
return gpuLibPaths
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
|
func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
|
||||||
var resp C.nvml_init_resp_t
|
|
||||||
resp.ch.verbose = getVerboseState()
|
|
||||||
for _, libPath := range nvmlLibPaths {
|
|
||||||
lib := C.CString(libPath)
|
|
||||||
defer C.free(unsafe.Pointer(lib))
|
|
||||||
C.nvml_init(lib, &resp)
|
|
||||||
if resp.err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
|
|
||||||
C.free(unsafe.Pointer(resp.err))
|
|
||||||
} else {
|
|
||||||
return &resp.ch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
|
|
||||||
var resp C.cudart_init_resp_t
|
var resp C.cudart_init_resp_t
|
||||||
resp.ch.verbose = getVerboseState()
|
resp.ch.verbose = getVerboseState()
|
||||||
for _, libPath := range cudartLibPaths {
|
for _, libPath := range cudartLibPaths {
|
||||||
|
@ -340,18 +316,54 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
|
||||||
defer C.free(unsafe.Pointer(lib))
|
defer C.free(unsafe.Pointer(lib))
|
||||||
C.cudart_init(lib, &resp)
|
C.cudart_init(lib, &resp)
|
||||||
if resp.err != nil {
|
if resp.err != nil {
|
||||||
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
|
slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
|
||||||
C.free(unsafe.Pointer(resp.err))
|
C.free(unsafe.Pointer(resp.err))
|
||||||
} else {
|
} else {
|
||||||
return &resp.ch
|
return int(resp.num_devices), &resp.ch, libPath
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return 0, nil, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
|
||||||
|
var resp C.nvcuda_init_resp_t
|
||||||
|
resp.ch.verbose = getVerboseState()
|
||||||
|
for _, libPath := range nvcudaLibPaths {
|
||||||
|
lib := C.CString(libPath)
|
||||||
|
defer C.free(unsafe.Pointer(lib))
|
||||||
|
C.nvcuda_init(lib, &resp)
|
||||||
|
if resp.err != nil {
|
||||||
|
slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
|
||||||
|
C.free(unsafe.Pointer(resp.err))
|
||||||
|
} else {
|
||||||
|
return int(resp.num_devices), &resp.ch, libPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func getVerboseState() C.uint16_t {
|
func getVerboseState() C.uint16_t {
|
||||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
if envconfig.Debug {
|
||||||
return C.uint16_t(1)
|
return C.uint16_t(1)
|
||||||
}
|
}
|
||||||
return C.uint16_t(0)
|
return C.uint16_t(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Given the list of GPUs this instantiation is targeted for,
|
||||||
|
// figure out the visible devices environment variable
|
||||||
|
//
|
||||||
|
// If different libraries are detected, the first one is what we use
|
||||||
|
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||||
|
if len(l) == 0 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
switch l[0].Library {
|
||||||
|
case "cuda":
|
||||||
|
return cudaGetVisibleDevicesEnv(l)
|
||||||
|
case "rocm":
|
||||||
|
return rocmGetVisibleDevicesEnv(l)
|
||||||
|
default:
|
||||||
|
slog.Debug("no filter required for library " + l[0].Library)
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -9,52 +9,47 @@ package gpu
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
|
const (
|
||||||
func CheckVRAM() (uint64, error) {
|
metalMinimumMemory = 384 * format.MebiByte
|
||||||
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
|
)
|
||||||
if userLimit != "" {
|
|
||||||
avail, err := strconv.ParseInt(userLimit, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
|
|
||||||
}
|
|
||||||
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
|
|
||||||
return uint64(avail), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
|
func GetGPUInfo() GpuInfoList {
|
||||||
|
mem, _ := GetCPUMem()
|
||||||
if runtime.GOARCH == "amd64" {
|
if runtime.GOARCH == "amd64" {
|
||||||
// gpu not supported, this may not be metal
|
return []GpuInfo{
|
||||||
return 0, nil
|
{
|
||||||
}
|
|
||||||
|
|
||||||
return uint64(C.getRecommendedMaxVRAM()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetGPUInfo() GpuInfo {
|
|
||||||
mem, _ := getCPUMem()
|
|
||||||
if runtime.GOARCH == "amd64" {
|
|
||||||
return GpuInfo{
|
|
||||||
Library: "cpu",
|
Library: "cpu",
|
||||||
Variant: GetCPUVariant(),
|
Variant: GetCPUVariant(),
|
||||||
memInfo: mem,
|
memInfo: mem,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return GpuInfo{
|
info := GpuInfo{
|
||||||
Library: "metal",
|
Library: "metal",
|
||||||
memInfo: mem,
|
ID: "0",
|
||||||
}
|
}
|
||||||
|
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
|
||||||
|
|
||||||
|
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
|
||||||
|
info.FreeMemory = info.TotalMemory
|
||||||
|
|
||||||
|
info.MinimumMemory = metalMinimumMemory
|
||||||
|
return []GpuInfo{info}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCPUMem() (memInfo, error) {
|
func GetCPUMem() (memInfo, error) {
|
||||||
return memInfo{
|
return memInfo{
|
||||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||||
FreeMemory: 0,
|
FreeMemory: 0,
|
||||||
DeviceCount: 1,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||||
|
// No-op on darwin
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
|
@ -38,12 +38,17 @@
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#define GPU_ID_LEN 64
|
||||||
|
|
||||||
typedef struct mem_info {
|
typedef struct mem_info {
|
||||||
|
char *err; // If non-nill, caller responsible for freeing
|
||||||
|
char gpu_id[GPU_ID_LEN];
|
||||||
uint64_t total;
|
uint64_t total;
|
||||||
uint64_t free;
|
uint64_t free;
|
||||||
unsigned int count;
|
|
||||||
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
|
// Compute Capability
|
||||||
char *err; // If non-nill, caller responsible for freeing
|
int major;
|
||||||
|
int minor;
|
||||||
} mem_info_t;
|
} mem_info_t;
|
||||||
|
|
||||||
void cpu_check_ram(mem_info_t *resp);
|
void cpu_check_ram(mem_info_t *resp);
|
||||||
|
@ -52,8 +57,8 @@ void cpu_check_ram(mem_info_t *resp);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "gpu_info_nvml.h"
|
|
||||||
#include "gpu_info_cudart.h"
|
#include "gpu_info_cudart.h"
|
||||||
|
#include "gpu_info_nvcuda.h"
|
||||||
|
|
||||||
#endif // __GPU_INFO_H__
|
#endif // __GPU_INFO_H__
|
||||||
#endif // __APPLE__
|
#endif // __APPLE__
|
|
@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||||
MEMORYSTATUSEX info;
|
MEMORYSTATUSEX info;
|
||||||
info.dwLength = sizeof(info);
|
info.dwLength = sizeof(info);
|
||||||
if (GlobalMemoryStatusEx(&info) != 0) {
|
if (GlobalMemoryStatusEx(&info) != 0) {
|
||||||
resp->count = 1;
|
|
||||||
resp->total = info.ullTotalPhys;
|
resp->total = info.ullTotalPhys;
|
||||||
resp->free = info.ullAvailPhys;
|
resp->free = info.ullAvailPhys;
|
||||||
|
resp->major = 0;
|
||||||
|
resp->minor = 0;
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||||
} else {
|
} else {
|
||||||
resp->err = LOAD_ERR();
|
resp->err = LOAD_ERR();
|
||||||
}
|
}
|
||||||
|
@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
|
||||||
if (sysinfo(&info) != 0) {
|
if (sysinfo(&info) != 0) {
|
||||||
resp->err = strdup(strerror(errno));
|
resp->err = strdup(strerror(errno));
|
||||||
} else {
|
} else {
|
||||||
resp->count = 1;
|
|
||||||
resp->total = info.totalram * info.mem_unit;
|
resp->total = info.totalram * info.mem_unit;
|
||||||
resp->free = info.freeram * info.mem_unit;
|
resp->free = info.freeram * info.mem_unit;
|
||||||
|
resp->major = 0;
|
||||||
|
resp->minor = 0;
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
cudartReturn_t ret;
|
cudartReturn_t ret;
|
||||||
resp->err = NULL;
|
resp->err = NULL;
|
||||||
|
resp->num_devices = 0;
|
||||||
const int buflen = 256;
|
const int buflen = 256;
|
||||||
char buf[buflen + 1];
|
char buf[buflen + 1];
|
||||||
int i;
|
int i;
|
||||||
|
@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
|
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
|
||||||
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
|
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
|
||||||
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
|
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
|
||||||
|
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
|
||||||
{NULL, NULL},
|
{NULL, NULL},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO once we've squashed the remaining corner cases remove this log
|
|
||||||
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
|
|
||||||
|
|
||||||
for (i = 0; l[i].s != NULL; i++) {
|
for (i = 0; l[i].s != NULL; i++) {
|
||||||
// TODO once we've squashed the remaining corner cases remove this log
|
|
||||||
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
|
||||||
|
|
||||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||||
if (!l[i].p) {
|
if (!l[i].p) {
|
||||||
char *msg = LOAD_ERR();
|
char *msg = LOAD_ERR();
|
||||||
|
@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
UNLOAD_LIBRARY(resp->ch.handle);
|
UNLOAD_LIBRARY(resp->ch.handle);
|
||||||
resp->ch.handle = NULL;
|
resp->ch.handle = NULL;
|
||||||
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
||||||
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
|
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
snprintf(buf, buflen, "cudart init failure: %d", ret);
|
snprintf(buf, buflen, "cudart init failure: %d", ret);
|
||||||
|
@ -85,42 +81,82 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||||
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
|
||||||
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
|
||||||
|
if (ret != CUDART_SUCCESS) {
|
||||||
|
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
|
||||||
|
UNLOAD_LIBRARY(resp->ch.handle);
|
||||||
|
resp->ch.handle = NULL;
|
||||||
|
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||||
|
resp->err = strdup(buf);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
|
void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||||
resp->err = NULL;
|
resp->err = NULL;
|
||||||
cudartMemory_t memInfo = {0,0,0};
|
cudartMemory_t memInfo = {0,0,0};
|
||||||
cudartReturn_t ret;
|
cudartReturn_t ret;
|
||||||
const int buflen = 256;
|
const int buflen = 256;
|
||||||
char buf[buflen + 1];
|
char buf[buflen + 1];
|
||||||
int i;
|
|
||||||
|
|
||||||
if (h.handle == NULL) {
|
if (h.handle == NULL) {
|
||||||
resp->err = strdup("cudart handle isn't initialized");
|
resp->err = strdup("cudart handle isn't initialized");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// cudaGetDeviceCount takes int type, resp-> count is uint
|
|
||||||
int deviceCount;
|
|
||||||
ret = (*h.cudaGetDeviceCount)(&deviceCount);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
resp->count = (unsigned int)deviceCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
resp->total = 0;
|
|
||||||
resp->free = 0;
|
|
||||||
for (i = 0; i < resp-> count; i++) {
|
|
||||||
ret = (*h.cudaSetDevice)(i);
|
ret = (*h.cudaSetDevice)(i);
|
||||||
if (ret != CUDART_SUCCESS) {
|
if (ret != CUDART_SUCCESS) {
|
||||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
snprintf(buf, buflen, "cudart device failed to initialize");
|
||||||
resp->err = strdup(buf);
|
resp->err = strdup(buf);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cudaDeviceProp_t props;
|
||||||
|
ret = (*h.cudaGetDeviceProperties)(&props, i);
|
||||||
|
if (ret != CUDART_SUCCESS) {
|
||||||
|
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||||
|
resp->major = 0;
|
||||||
|
resp->minor = 0;
|
||||||
|
} else {
|
||||||
|
int allNull = 1;
|
||||||
|
for (int j = 0; j < 16; j++) {
|
||||||
|
if (props.uuid.bytes[j] != 0) {
|
||||||
|
allNull = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (allNull != 0) {
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
|
||||||
|
} else {
|
||||||
|
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
|
||||||
|
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
|
||||||
|
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||||
|
props.uuid.bytes[0],
|
||||||
|
props.uuid.bytes[1],
|
||||||
|
props.uuid.bytes[2],
|
||||||
|
props.uuid.bytes[3],
|
||||||
|
props.uuid.bytes[4],
|
||||||
|
props.uuid.bytes[5],
|
||||||
|
props.uuid.bytes[6],
|
||||||
|
props.uuid.bytes[7],
|
||||||
|
props.uuid.bytes[8],
|
||||||
|
props.uuid.bytes[9],
|
||||||
|
props.uuid.bytes[10],
|
||||||
|
props.uuid.bytes[11],
|
||||||
|
props.uuid.bytes[12],
|
||||||
|
props.uuid.bytes[13],
|
||||||
|
props.uuid.bytes[14],
|
||||||
|
props.uuid.bytes[15]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
resp->major = props.major;
|
||||||
|
resp->minor = props.minor;
|
||||||
|
|
||||||
|
// TODO add other useful properties from props
|
||||||
|
}
|
||||||
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
|
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
|
||||||
if (ret != CUDART_SUCCESS) {
|
if (ret != CUDART_SUCCESS) {
|
||||||
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
|
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
|
||||||
|
@ -128,67 +164,12 @@ void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
|
resp->total = memInfo.total;
|
||||||
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
|
resp->free = memInfo.free;
|
||||||
|
|
||||||
resp->total += memInfo.total;
|
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
|
||||||
resp->free += memInfo.free;
|
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
|
||||||
}
|
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||||
}
|
|
||||||
|
|
||||||
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
|
|
||||||
resp->err = NULL;
|
|
||||||
resp->major = 0;
|
|
||||||
resp->minor = 0;
|
|
||||||
int major = 0;
|
|
||||||
int minor = 0;
|
|
||||||
cudartReturn_t ret;
|
|
||||||
const int buflen = 256;
|
|
||||||
char buf[buflen + 1];
|
|
||||||
int i;
|
|
||||||
|
|
||||||
if (h.handle == NULL) {
|
|
||||||
resp->err = strdup("cudart handle not initialized");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int devices;
|
|
||||||
ret = (*h.cudaGetDeviceCount)(&devices);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (i = 0; i < devices; i++) {
|
|
||||||
ret = (*h.cudaSetDevice)(i);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "cudart device failed to initialize");
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
|
|
||||||
if (ret != CUDART_SUCCESS) {
|
|
||||||
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
|
|
||||||
resp->err = strdup(buf);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Report the lowest major.minor we detect as that limits our compatibility
|
|
||||||
if (resp->major == 0 || resp->major > major ) {
|
|
||||||
resp->major = major;
|
|
||||||
resp->minor = minor;
|
|
||||||
} else if ( resp->major == major && resp->minor > minor ) {
|
|
||||||
resp->minor = minor;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void cudart_release(cudart_handle_t h) {
|
void cudart_release(cudart_handle_t h) {
|
||||||
|
|
|
@ -6,14 +6,20 @@
|
||||||
// Just enough typedef's to dlopen/dlsym for memory information
|
// Just enough typedef's to dlopen/dlsym for memory information
|
||||||
typedef enum cudartReturn_enum {
|
typedef enum cudartReturn_enum {
|
||||||
CUDART_SUCCESS = 0,
|
CUDART_SUCCESS = 0,
|
||||||
CUDART_UNSUPPORTED = 1,
|
CUDART_ERROR_INVALID_VALUE = 1,
|
||||||
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
CUDART_ERROR_MEMORY_ALLOCATION = 2,
|
||||||
|
CUDART_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||||
// Other values omitted for now...
|
// Other values omitted for now...
|
||||||
} cudartReturn_t;
|
} cudartReturn_t;
|
||||||
|
|
||||||
typedef enum cudartDeviceAttr_enum {
|
typedef enum cudartDeviceAttr_enum {
|
||||||
cudartDevAttrComputeCapabilityMajor = 75,
|
cudartDevAttrComputeCapabilityMajor = 75,
|
||||||
cudartDevAttrComputeCapabilityMinor = 76,
|
cudartDevAttrComputeCapabilityMinor = 76,
|
||||||
|
|
||||||
|
// TODO - not yet wired up but may be useful for Jetson or other
|
||||||
|
// integrated GPU scenarios with shared memory
|
||||||
|
cudaDevAttrIntegrated = 18
|
||||||
|
|
||||||
} cudartDeviceAttr_t;
|
} cudartDeviceAttr_t;
|
||||||
|
|
||||||
typedef void *cudartDevice_t; // Opaque is sufficient
|
typedef void *cudartDevice_t; // Opaque is sufficient
|
||||||
|
@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
|
||||||
int minor;
|
int minor;
|
||||||
} cudartDriverVersion_t;
|
} cudartDriverVersion_t;
|
||||||
|
|
||||||
|
typedef struct cudaUUID {
|
||||||
|
unsigned char bytes[16];
|
||||||
|
} cudaUUID_t;
|
||||||
|
typedef struct cudaDeviceProp {
|
||||||
|
char name[256]; /**< ASCII string identifying device */
|
||||||
|
cudaUUID_t uuid; /**< 16-byte unique identifier */
|
||||||
|
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
|
||||||
|
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
|
||||||
|
size_t totalGlobalMem; /**< Global memory available on device in bytes */
|
||||||
|
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
|
||||||
|
int regsPerBlock; /**< 32-bit registers available per block */
|
||||||
|
int warpSize; /**< Warp size in threads */
|
||||||
|
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
|
||||||
|
int maxThreadsPerBlock; /**< Maximum number of threads per block */
|
||||||
|
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
|
||||||
|
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
|
||||||
|
int clockRate; /**< Clock frequency in kilohertz */
|
||||||
|
size_t totalConstMem; /**< Constant memory available on device in bytes */
|
||||||
|
int major; /**< Major compute capability */
|
||||||
|
int minor; /**< Minor compute capability */
|
||||||
|
size_t textureAlignment; /**< Alignment requirement for textures */
|
||||||
|
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
|
||||||
|
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
|
||||||
|
int multiProcessorCount; /**< Number of multiprocessors on device */
|
||||||
|
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
|
||||||
|
int integrated; /**< Device is integrated as opposed to discrete */
|
||||||
|
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
|
||||||
|
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
|
||||||
|
int maxTexture1D; /**< Maximum 1D texture size */
|
||||||
|
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
|
||||||
|
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
|
||||||
|
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
|
||||||
|
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
|
||||||
|
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
|
||||||
|
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
|
||||||
|
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
|
||||||
|
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
|
||||||
|
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
|
||||||
|
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
|
||||||
|
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
|
||||||
|
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
|
||||||
|
int maxSurface1D; /**< Maximum 1D surface size */
|
||||||
|
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
|
||||||
|
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
|
||||||
|
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
|
||||||
|
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
|
||||||
|
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
|
||||||
|
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
|
||||||
|
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
|
||||||
|
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
|
||||||
|
int ECCEnabled; /**< Device has ECC support enabled */
|
||||||
|
int pciBusID; /**< PCI bus ID of the device */
|
||||||
|
int pciDeviceID; /**< PCI device ID of the device */
|
||||||
|
int pciDomainID; /**< PCI domain ID of the device */
|
||||||
|
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
|
||||||
|
int asyncEngineCount; /**< Number of asynchronous engines */
|
||||||
|
int unifiedAddressing; /**< Device shares a unified address space with the host */
|
||||||
|
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
|
||||||
|
int memoryBusWidth; /**< Global memory bus width in bits */
|
||||||
|
int l2CacheSize; /**< Size of L2 cache in bytes */
|
||||||
|
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
|
||||||
|
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
|
||||||
|
int streamPrioritiesSupported; /**< Device supports stream priorities */
|
||||||
|
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
|
||||||
|
int localL1CacheSupported; /**< Device supports caching locals in L1 */
|
||||||
|
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
|
||||||
|
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
|
||||||
|
int managedMemory; /**< Device supports allocating managed memory on this system */
|
||||||
|
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
|
||||||
|
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
|
||||||
|
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
|
||||||
|
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
|
||||||
|
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
|
||||||
|
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
|
||||||
|
int computePreemptionSupported; /**< Device supports Compute Preemption */
|
||||||
|
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
|
||||||
|
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
|
||||||
|
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
|
||||||
|
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
|
||||||
|
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
|
||||||
|
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
|
||||||
|
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
|
||||||
|
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
|
||||||
|
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
|
||||||
|
} cudaDeviceProp_t;
|
||||||
|
|
||||||
typedef struct cudart_handle {
|
typedef struct cudart_handle {
|
||||||
void *handle;
|
void *handle;
|
||||||
uint16_t verbose;
|
uint16_t verbose;
|
||||||
|
@ -38,23 +130,17 @@ typedef struct cudart_handle {
|
||||||
cudartReturn_t (*cudaGetDeviceCount)(int *);
|
cudartReturn_t (*cudaGetDeviceCount)(int *);
|
||||||
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
|
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
|
||||||
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
|
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
|
||||||
|
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
|
||||||
} cudart_handle_t;
|
} cudart_handle_t;
|
||||||
|
|
||||||
typedef struct cudart_init_resp {
|
typedef struct cudart_init_resp {
|
||||||
char *err; // If err is non-null handle is invalid
|
char *err; // If err is non-null handle is invalid
|
||||||
cudart_handle_t ch;
|
cudart_handle_t ch;
|
||||||
|
int num_devices;
|
||||||
} cudart_init_resp_t;
|
} cudart_init_resp_t;
|
||||||
|
|
||||||
typedef struct cudart_compute_capability {
|
|
||||||
char *err;
|
|
||||||
int major;
|
|
||||||
int minor;
|
|
||||||
} cudart_compute_capability_t;
|
|
||||||
|
|
||||||
|
|
||||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
||||||
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
|
void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
|
||||||
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
|
|
||||||
void cudart_release(cudart_handle_t ch);
|
void cudart_release(cudart_handle_t ch);
|
||||||
|
|
||||||
#endif // __GPU_INFO_CUDART_H__
|
#endif // __GPU_INFO_CUDART_H__
|
||||||
|
|
203
gpu/gpu_info_nvcuda.c
Normal file
203
gpu/gpu_info_nvcuda.c
Normal file
|
@ -0,0 +1,203 @@
|
||||||
|
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
|
#include "gpu_info_nvcuda.h"
|
||||||
|
|
||||||
|
// Opens the nvcuda library at nvcuda_lib_path, resolves the driver entry
// points listed in the table below into resp->ch, calls cuInit and stores the
// device count in resp->num_devices.
//
// On failure resp->err is set to a strdup'd message. If the library had
// already been opened it is unloaded again and resp->ch.handle is reset to
// NULL. On success resp->err stays NULL.
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
  const int buflen = 256;
  char buf[buflen + 1];
  CUresult ret;

  resp->err = NULL;
  resp->num_devices = 0;

  // Symbol name -> slot in the handle that receives the resolved address.
  // The table is terminated by a NULL name.
  struct symbol {
    char *name;
    void **slot;
  } symbols[] = {
      {"cuInit", (void *)&resp->ch.cuInit},
      {"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion},
      {"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount},
      {"cuDeviceGet", (void *)&resp->ch.cuDeviceGet},
      {"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute},
      {"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid},
      {"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3},
      {"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2},
      {"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy},
      {NULL, NULL},
  };
  struct symbol *sym;

  resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY);
  if (!resp->ch.handle) {
    char *msg = LOAD_ERR();
    LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg);
    snprintf(buf, buflen,
             "Unable to load %s library to query for Nvidia GPUs: %s",
             nvcuda_lib_path, msg);
    free(msg);
    resp->err = strdup(buf);
    return;
  }

  // Every listed symbol is required: the first one that cannot be resolved
  // aborts initialization.
  for (sym = symbols; sym->name != NULL; sym++) {
    *sym->slot = LOAD_SYMBOL(resp->ch.handle, sym->name);
    if (*sym->slot) {
      continue;
    }
    // Fetch the loader error before unloading the library
    char *msg = LOAD_ERR();
    LOG(resp->ch.verbose, "dlerr: %s\n", msg);
    UNLOAD_LIBRARY(resp->ch.handle);
    resp->ch.handle = NULL;
    snprintf(buf, buflen, "symbol lookup for %s failed: %s", sym->name, msg);
    free(msg);
    resp->err = strdup(buf);
    return;
  }

  ret = (*resp->ch.cuInit)(0);
  if (ret != CUDA_SUCCESS) {
    LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
    UNLOAD_LIBRARY(resp->ch.handle);
    resp->ch.handle = NULL;
    if (ret != CUDA_ERROR_INSUFFICIENT_DRIVER) {
      snprintf(buf, buflen, "nvcuda init failure: %d", ret);
      resp->err = strdup(buf);
    } else {
      resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
    }
    return;
  }

  // Report driver version if we're in verbose mode, ignore errors
  int version = 0;
  nvcudaDriverVersion_t driverVersion;
  driverVersion.major = 0;
  driverVersion.minor = 0;
  ret = (*resp->ch.cuDriverGetVersion)(&version);
  if (ret == CUDA_SUCCESS) {
    driverVersion.major = version / 1000;
    driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
    LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
  } else {
    LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret);
  }

  ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices);
  if (ret == CUDA_SUCCESS) {
    return;
  }
  LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret);
  UNLOAD_LIBRARY(resp->ch.handle);
  resp->ch.handle = NULL;
  snprintf(buf, buflen, "unable to get device count: %d", ret);
  resp->err = strdup(buf);
}
|
||||||
|
|
||||||
|
const int buflen = 256;
|
||||||
|
// Fills resp with the id, compute capability and memory figures of the CUDA
// device at index i, using the driver entry points held in h.
//
// resp->gpu_id is the "GPU-<uuid>" string when the UUID lookup succeeds and the
// decimal device index otherwise. resp->major/minor stay 0 when either
// compute capability attribute cannot be read. Both of those lookups are
// best effort and only logged.
//
// resp->err is NULL on success. On failure it is a strdup'd message (the
// caller is responsible for freeing it) and the memory figures are left at 0.
void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
  // Start from a fully defined response so that early error returns never
  // hand indeterminate fields back to the caller
  resp->err = NULL;
  resp->total = 0;
  resp->free = 0;
  resp->major = 0;
  resp->minor = 0;
  resp->gpu_id[0] = '\0';
  nvcudaMemory_t memInfo = {0,0};
  CUresult ret;
  CUdevice device = -1;
  CUcontext ctx = NULL;
  char buf[buflen + 1];
  CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};

  if (h.handle == NULL) {
    resp->err = strdup("nvcuda handle isn't initialized");
    return;
  }

  ret = (*h.cuDeviceGet)(&device, i);
  if (ret != CUDA_SUCCESS) {
    snprintf(buf, buflen, "nvcuda device failed to initialize");
    resp->err = strdup(buf);
    return;
  }

  // Compute capability is only reported when both halves were read
  int major = 0;
  int minor = 0;
  ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
  if (ret != CUDA_SUCCESS) {
    LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret);
  } else {
    ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
    if (ret != CUDA_SUCCESS) {
      LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret);
    } else {
      resp->minor = minor;
      resp->major = major;
    }
  }

  ret = (*h.cuDeviceGetUuid)(&uuid, device);
  if (ret != CUDA_SUCCESS) {
    LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret);
    // Fall back to the device index as the id
    snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
  } else {
    // GPU-d110a105-ac29-1d54-7b49-9c90440f215b
    snprintf(&resp->gpu_id[0], GPU_ID_LEN,
        "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
        uuid.bytes[0],
        uuid.bytes[1],
        uuid.bytes[2],
        uuid.bytes[3],
        uuid.bytes[4],
        uuid.bytes[5],
        uuid.bytes[6],
        uuid.bytes[7],
        uuid.bytes[8],
        uuid.bytes[9],
        uuid.bytes[10],
        uuid.bytes[11],
        uuid.bytes[12],
        uuid.bytes[13],
        uuid.bytes[14],
        uuid.bytes[15]
      );
  }

  // To get memory we have to set (and release) a context
  ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
  if (ret != CUDA_SUCCESS) {
    snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
    resp->err = strdup(buf);
    return;
  }

  ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
  if (ret != CUDA_SUCCESS) {
    snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
    resp->err = strdup(buf);
    // Best effort on failure...
    (*h.cuCtxDestroy)(ctx);
    return;
  }

  resp->total = memInfo.total;
  resp->free = memInfo.free;

  // total/free are uint64_t: "%lu" only matches where unsigned long is 64 bit,
  // which is not the case on Windows, so print through unsigned long long
  LOG(h.verbose, "[%s] CUDA totalMem %llu mb\n", resp->gpu_id, (unsigned long long)(resp->total / 1024 / 1024));
  LOG(h.verbose, "[%s] CUDA freeMem %llu mb\n", resp->gpu_id, (unsigned long long)(resp->free / 1024 / 1024));
  LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);

  ret = (*h.cuCtxDestroy)(ctx);
  if (ret != CUDA_SUCCESS) {
    // Logged unconditionally; the device information gathered above is still
    // returned to the caller
    LOG(1, "nvcuda failed to release primary device context %d\n", ret);
  }
}
|
||||||
|
|
||||||
|
// Unloads the nvcuda shared library referenced by h.
//
// h is passed by value, so this function cannot clear the caller's copy of
// the handle; callers must treat their own handle as invalid once this
// returns.
void nvcuda_release(nvcuda_handle_t h) {
  if (h.handle == NULL) {
    // Nothing is loaded (or it was already released); do not hand a NULL
    // handle to the platform unload call.
    return;
  }
  LOG(h.verbose, "releasing nvcuda library\n");
  UNLOAD_LIBRARY(h.handle);
  // TODO and other context release logic?
}
|
||||||
|
|
||||||
|
#endif // __APPLE__
|
71
gpu/gpu_info_nvcuda.h
Normal file
71
gpu/gpu_info_nvcuda.h
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
#ifndef __APPLE__
#ifndef __GPU_INFO_NVCUDA_H__
#define __GPU_INFO_NVCUDA_H__
#include "gpu_info.h"

// Just enough typedef's to dlopen/dlsym for memory information

// Result codes returned by the function pointers stored in nvcuda_handle_t.
// Only the values declared here are named; any other value is reported
// numerically by callers.
typedef enum cudaError_enum {
  CUDA_SUCCESS = 0,
  CUDA_ERROR_INVALID_VALUE = 1,
  CUDA_ERROR_MEMORY_ALLOCATION = 2,
  CUDA_ERROR_NOT_INITIALIZED = 3,
  CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
  // Other values omitted for now...
} CUresult;

// Attribute selectors accepted by the cuDeviceGetAttribute function pointer.
typedef enum CUdevice_attribute_enum {
  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,

  // TODO - not yet wired up but may be useful for Jetson or other
  // integrated GPU scenarios with shared memory
  CU_DEVICE_ATTRIBUTE_INTEGRATED = 18

} CUdevice_attribute;

typedef void *nvcudaDevice_t;  // Opaque is sufficient

// Total/free memory pair filled in through cuMemGetInfo_v2.
// NOTE(review): units are presumably bytes — confirm against the caller.
typedef struct nvcudaMemory_st {
  uint64_t total;
  uint64_t free;
} nvcudaMemory_t;

// Driver version split into major and minor components.
typedef struct nvcudaDriverVersion {
  int major;
  int minor;
} nvcudaDriverVersion_t;

// 16 raw UUID bytes as written by cuDeviceGetUuid.
typedef struct CUuuid_st {
  unsigned char bytes[16];
} CUuuid;

// Device ordinal handle and opaque context handle used by the function
// pointers below.
typedef int CUdevice;
typedef void* CUcontext;

// Loaded-library handle plus the table of entry points resolved from it.
// `handle` is the value returned by the platform library loader; `verbose`
// is passed to LOG by the implementation to gate diagnostic output.
typedef struct nvcuda_handle {
  void *handle;
  uint16_t verbose;
  CUresult (*cuInit)(unsigned int Flags);
  CUresult (*cuDriverGetVersion)(int *driverVersion);
  CUresult (*cuDeviceGetCount)(int *);
  CUresult (*cuDeviceGet)(CUdevice* device, int ordinal);
  CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev);
  CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2

  // Context specific aspects
  CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev);
  CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total);
  CUresult (*cuCtxDestroy)(CUcontext ctx);
} nvcuda_handle_t;

// Result of nvcuda_init.
typedef struct nvcuda_init_resp {
  char *err;  // If err is non-null handle is invalid
  nvcuda_handle_t ch;
  int num_devices;
} nvcuda_init_resp_t;

// Entry points implemented in the matching .c file.
// NOTE(review): presumably nvcuda_init loads the library at nvcuda_lib_path
// and nvcuda_check_vram fills *resp for one device — confirm against the
// implementation, which is not part of this header.
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
void nvcuda_release(nvcuda_handle_t ch);

#endif  // __GPU_INFO_NVCUDA_H__
#endif  // __APPLE__
|
|
@ -1,221 +0,0 @@
|
||||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
|
||||||
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
#include "gpu_info_nvml.h"
|
|
||||||
|
|
||||||
// Loads the NVML shared library at nvml_lib_path, resolves every entry point
// named in the lookup table into resp->ch, and calls nvmlInit_v2.
//
// On success resp->err is NULL and resp->ch.handle refers to the loaded
// library. On any failure resp->err is set to a strdup()'d message and
// resp->ch.handle is NULL; if the library had already been loaded it is
// unloaded before returning.
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
  nvmlReturn_t ret;
  resp->err = NULL;
  const int buflen = 256;
  char buf[buflen + 1];
  int i;

  // Symbol name -> address of the function-pointer slot that receives it.
  struct lookup {
    char *s;
    void **p;
  } l[] = {
      {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
      {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
      {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
      {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
      {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
      {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
      {"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
      {"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
      {"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
      {"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
      {"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
      {"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
      {NULL, NULL},
  };

  resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
  if (!resp->ch.handle) {
    // NOTE(review): msg is freed below, which assumes LOAD_ERR returns
    // heap-allocated memory — the macro is defined outside this block.
    char *msg = LOAD_ERR();
    LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
    snprintf(buf, buflen,
             "Unable to load %s library to query for Nvidia GPUs: %s",
             nvml_lib_path, msg);
    free(msg);
    resp->err = strdup(buf);
    return;
  }

  // TODO once we've squashed the remaining corner cases remove this log
  LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);

  for (i = 0; l[i].s != NULL; i++) {
    // TODO once we've squashed the remaining corner cases remove this log
    LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);

    *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
    // Test the resolved symbol, not the slot address: l[i].p always points
    // at a struct member and is never NULL.
    if (!*l[i].p) {
      // Fetch the loader error before unloading so it describes the failed
      // lookup.
      char *msg = LOAD_ERR();
      LOG(resp->ch.verbose, "dlerr: %s\n", msg);
      // Unload while the handle is still valid, then invalidate it.
      UNLOAD_LIBRARY(resp->ch.handle);
      resp->ch.handle = NULL;
      snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
               msg);
      free(msg);
      resp->err = strdup(buf);
      return;
    }
  }

  ret = (*resp->ch.nvmlInit_v2)();
  if (ret != NVML_SUCCESS) {
    LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
    UNLOAD_LIBRARY(resp->ch.handle);
    resp->ch.handle = NULL;
    snprintf(buf, buflen, "nvml vram init failure: %d", ret);
    resp->err = strdup(buf);
    return;
  }

  // Report driver version if we're in verbose mode, ignore errors
  ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
  if (ret != NVML_SUCCESS) {
    LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
  } else {
    LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
  }
}
|
|
||||||
|
|
||||||
// Sums total and free memory across every device NVML reports.
//
// On success resp->err is NULL, resp->count holds the device count and
// resp->total / resp->free hold the sums over all devices. On failure
// resp->err is set to a strdup()'d message and the function returns early;
// totals accumulated up to that point are left in resp.
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
  resp->err = NULL;
  nvmlDevice_t device;
  nvmlMemory_t memInfo = {0};
  nvmlReturn_t ret;
  const int buflen = 256;
  char buf[buflen + 1];
  int i;

  if (h.handle == NULL) {
    resp->err = strdup("nvml handle isn't initialized");
    return;
  }

  ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
  if (ret != NVML_SUCCESS) {
    snprintf(buf, buflen, "unable to get device count: %d", ret);
    resp->err = strdup(buf);
    return;
  }

  resp->total = 0;
  resp->free = 0;
  for (i = 0; i < resp->count; i++) {
    ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
    if (ret != NVML_SUCCESS) {
      snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
      resp->err = strdup(buf);
      return;
    }

    ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
    if (ret != NVML_SUCCESS) {
      snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
      resp->err = strdup(buf);
      return;
    }
    if (h.verbose) {
      nvmlBrandType_t brand = 0;
      // When in verbose mode, report more information about
      // the card we discover, but don't fail on error
      ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
      if (ret != NVML_SUCCESS) {
        LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
      } else {
        LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
      }
      ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
      if (ret != NVML_SUCCESS) {
        LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
      } else {
        LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
      }
      ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
      if (ret != NVML_SUCCESS) {
        LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
      } else {
        LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
      }
      ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
      if (ret != NVML_SUCCESS) {
        LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
      } else {
        LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
      }
      ret = (*h.nvmlDeviceGetBrand)(device, &brand);
      if (ret != NVML_SUCCESS) {
        LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
      } else {
        LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
      }
    }

    // nvmlMemory_t fields are unsigned long long, so they need %llu; %ld
    // mismatches the argument type (and truncates where long is 32 bits).
    // NOTE(review): assumes LOG forwards to a printf-style formatter.
    LOG(h.verbose, "[%d] CUDA totalMem %llu\n", i, memInfo.total);
    LOG(h.verbose, "[%d] CUDA freeMem %llu\n", i, memInfo.free);

    resp->total += memInfo.total;
    resp->free += memInfo.free;
  }
}
|
|
||||||
|
|
||||||
// Reports the lowest CUDA compute capability (major.minor) found among the
// devices NVML enumerates.
//
// resp->major and resp->minor start at 0 and are lowered as devices are
// inspected. On any failure resp->err receives a strdup()'d message and the
// function returns with whatever was recorded so far.
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
  const int msglen = 256;
  char msg[msglen + 1];
  nvmlDevice_t dev;
  nvmlReturn_t rc;
  unsigned int device_count;
  int dev_major = 0;
  int dev_minor = 0;
  int idx;

  resp->err = NULL;
  resp->major = 0;
  resp->minor = 0;

  if (h.handle == NULL) {
    resp->err = strdup("nvml handle not initialized");
    return;
  }

  rc = (*h.nvmlDeviceGetCount_v2)(&device_count);
  if (rc != NVML_SUCCESS) {
    snprintf(msg, msglen, "unable to get device count: %d", rc);
    resp->err = strdup(msg);
    return;
  }

  for (idx = 0; idx < device_count; idx++) {
    rc = (*h.nvmlDeviceGetHandleByIndex)(idx, &dev);
    if (rc != NVML_SUCCESS) {
      snprintf(msg, msglen, "unable to get device handle %d: %d", idx, rc);
      resp->err = strdup(msg);
      return;
    }

    rc = (*h.nvmlDeviceGetCudaComputeCapability)(dev, &dev_major, &dev_minor);
    if (rc != NVML_SUCCESS) {
      snprintf(msg, msglen, "device compute capability lookup failure %d: %d", idx, rc);
      resp->err = strdup(msg);
      return;
    }

    // Report the lowest major.minor we detect as that limits our compatibility
    if (resp->major != 0 && resp->major <= dev_major) {
      // Already at or below this device's major: only a lower minor on the
      // same major can tighten the result.
      if (resp->major == dev_major && resp->minor > dev_minor) {
        resp->minor = dev_minor;
      }
      continue;
    }
    // First device seen, or a strictly lower major version.
    resp->major = dev_major;
    resp->minor = dev_minor;
  }
}
|
|
||||||
|
|
||||||
// Unloads the NVML shared library referenced by h.
//
// h is passed by value, so this function cannot clear the caller's copy of
// the handle; callers must treat their own handle as invalid once this
// returns.
void nvml_release(nvml_handle_t h) {
  if (h.handle == NULL) {
    // Nothing is loaded (or it was already released); do not hand a NULL
    // handle to the platform unload call.
    return;
  }
  LOG(h.verbose, "releasing nvml library\n");
  UNLOAD_LIBRARY(h.handle);
}
|
|
||||||
|
|
||||||
#endif // __APPLE__
|
|
|
@ -1,57 +0,0 @@
|
||||||
#ifndef __APPLE__
#ifndef __GPU_INFO_NVML_H__
#define __GPU_INFO_NVML_H__
#include "gpu_info.h"

// Just enough typedef's to dlopen/dlsym for memory information

// Result codes returned by the NVML function pointers below; only success is
// named here.
typedef enum nvmlReturn_enum {
  NVML_SUCCESS = 0,
  // Other values omitted for now...
} nvmlReturn_t;
typedef void *nvmlDevice_t;  // Opaque is sufficient

// Memory figures filled in through nvmlDeviceGetMemoryInfo.
// NOTE(review): units are presumably bytes — confirm against NVML docs.
typedef struct nvmlMemory_st {
  unsigned long long total;
  unsigned long long free;
  unsigned long long used;
} nvmlMemory_t;

// Brand identifier written by nvmlDeviceGetBrand; only the unknown value is
// named here.
typedef enum nvmlBrandType_enum
{
  NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;

// Loaded-library handle plus the table of NVML entry points resolved from
// it. `verbose` is passed to LOG by the implementation to gate diagnostics.
typedef struct nvml_handle {
  void *handle;
  uint16_t verbose;
  nvmlReturn_t (*nvmlInit_v2)(void);
  nvmlReturn_t (*nvmlShutdown)(void);
  nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
  nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
  nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
  nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
  nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
  nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
  nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
  nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
  nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
  nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
} nvml_handle_t;

// Result of nvml_init.
typedef struct nvml_init_resp {
  char *err;  // If err is non-null handle is invalid
  nvml_handle_t ch;
} nvml_init_resp_t;

// Result of nvml_compute_capability: err is non-NULL on failure, otherwise
// major/minor carry the reported compute capability.
typedef struct nvml_compute_capability {
  char *err;
  int major;
  int minor;
} nvml_compute_capability_t;

// Entry points implemented in the matching .c file.
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
void nvml_release(nvml_handle_t ch);

#endif  // __GPU_INFO_NVML_H__
#endif  // __APPLE__
|
|
|
@ -9,23 +9,16 @@ import (
|
||||||
|
|
||||||
func TestBasicGetGPUInfo(t *testing.T) {
|
func TestBasicGetGPUInfo(t *testing.T) {
|
||||||
info := GetGPUInfo()
|
info := GetGPUInfo()
|
||||||
assert.Contains(t, "cuda rocm cpu metal", info.Library)
|
assert.Greater(t, len(info), 0)
|
||||||
|
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
|
||||||
switch runtime.GOOS {
|
if info[0].Library != "cpu" {
|
||||||
case "darwin":
|
assert.Greater(t, info[0].TotalMemory, uint64(0))
|
||||||
// TODO - remove this once MacOS returns some size for CPU
|
assert.Greater(t, info[0].FreeMemory, uint64(0))
|
||||||
return
|
|
||||||
case "linux", "windows":
|
|
||||||
assert.Greater(t, info.TotalMemory, uint64(0))
|
|
||||||
assert.Greater(t, info.FreeMemory, uint64(0))
|
|
||||||
assert.Greater(t, info.DeviceCount, uint32(0))
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCPUMemInfo(t *testing.T) {
|
func TestCPUMemInfo(t *testing.T) {
|
||||||
info, err := getCPUMem()
|
info, err := GetCPUMem()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "darwin":
|
case "darwin":
|
||||||
|
|
49
gpu/types.go
49
gpu/types.go
|
@ -3,7 +3,6 @@ package gpu
|
||||||
type memInfo struct {
|
type memInfo struct {
|
||||||
TotalMemory uint64 `json:"total_memory,omitempty"`
|
TotalMemory uint64 `json:"total_memory,omitempty"`
|
||||||
FreeMemory uint64 `json:"free_memory,omitempty"`
|
FreeMemory uint64 `json:"free_memory,omitempty"`
|
||||||
DeviceCount uint32 `json:"device_count,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Beginning of an `ollama info` command
|
// Beginning of an `ollama info` command
|
||||||
|
@ -17,11 +16,49 @@ type GpuInfo struct {
|
||||||
// MinimumMemory represents the minimum memory required to use the GPU
|
// MinimumMemory represents the minimum memory required to use the GPU
|
||||||
MinimumMemory uint64 `json:"-"`
|
MinimumMemory uint64 `json:"-"`
|
||||||
|
|
||||||
// TODO add other useful attributes about the card here for discovery information
|
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||||
|
DependencyPath string `json:"lib_path,omitempty"`
|
||||||
|
|
||||||
|
// GPU information
|
||||||
|
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||||
|
Name string `json:"name"` // user friendly name if available
|
||||||
|
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
|
||||||
|
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
|
||||||
|
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
|
||||||
|
|
||||||
|
// TODO other performance capability info to help in scheduling decisions
|
||||||
}
|
}
|
||||||
|
|
||||||
type Version struct {
|
type GpuInfoList []GpuInfo
|
||||||
Major uint
|
|
||||||
Minor uint
|
// Split up the set of gpu info's by Library and variant
|
||||||
Patch uint
|
func (l GpuInfoList) ByLibrary() []GpuInfoList {
|
||||||
|
resp := []GpuInfoList{}
|
||||||
|
libs := []string{}
|
||||||
|
for _, info := range l {
|
||||||
|
found := false
|
||||||
|
requested := info.Library
|
||||||
|
if info.Variant != "" {
|
||||||
|
requested += "_" + info.Variant
|
||||||
|
}
|
||||||
|
for i, lib := range libs {
|
||||||
|
if lib == requested {
|
||||||
|
resp[i] = append(resp[i], info)
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
libs = append(libs, info.Library)
|
||||||
|
resp = append(resp, []GpuInfo{info})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sort by Free Space
|
||||||
|
type ByFreeMemory []GpuInfo
|
||||||
|
|
||||||
|
func (a ByFreeMemory) Len() int { return len(a) }
|
||||||
|
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||||
|
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
|
||||||
|
|
|
@ -4,11 +4,14 @@ package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOrcaMiniBlueSky(t *testing.T) {
|
func TestOrcaMiniBlueSky(t *testing.T) {
|
||||||
|
@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||||
"seed": 123,
|
"seed": 123,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
|
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnicodeModelDir(t *testing.T) {
|
||||||
|
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
t.Skip("Unicode test only applicable to windows")
|
||||||
|
}
|
||||||
|
// Only works for local testing
|
||||||
|
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||||
|
t.Skip("TestUnicodeModelDir only works for local testing, skipping")
|
||||||
|
}
|
||||||
|
|
||||||
|
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(modelDir)
|
||||||
|
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||||
|
|
||||||
|
oldModelsDir := os.Getenv("OLLAMA_MODELS")
|
||||||
|
if oldModelsDir == "" {
|
||||||
|
defer os.Unsetenv("OLLAMA_MODELS")
|
||||||
|
} else {
|
||||||
|
defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
|
||||||
|
}
|
||||||
|
err = os.Setenv("OLLAMA_MODELS", modelDir)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.GenerateRequest{
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the sky blue?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"temperature": 0,
|
||||||
|
"seed": 123,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||||
}
|
}
|
||||||
|
|
225
integration/concurrency_test.go
Normal file
225
integration/concurrency_test.go
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMultiModelConcurrency(t *testing.T) {
|
||||||
|
var (
|
||||||
|
req = [2]api.GenerateRequest{
|
||||||
|
{
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the ocean blue?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "tinydolphin",
|
||||||
|
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp = [2][]string{
|
||||||
|
[]string{"sunlight"},
|
||||||
|
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(req))
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||||
|
defer cancel()
|
||||||
|
for i := 0; i < len(req); i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
GenerateTestHelper(ctx, t, req[i], resp[i])
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
req, resp := GenerateRequests()
|
||||||
|
// Get the server running (if applicable) warm the model up with a single initial request
|
||||||
|
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(req))
|
||||||
|
for i := 0; i < len(req); i++ {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 5; j++ {
|
||||||
|
slog.Info("Starting", "req", i, "iter", j)
|
||||||
|
// On slower GPUs it can take a while to process the 4 concurrent requests
|
||||||
|
// so we allow a much longer initial timeout
|
||||||
|
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||||
|
func TestMultiModelStress(t *testing.T) {
|
||||||
|
vram := os.Getenv("OLLAMA_MAX_VRAM")
|
||||||
|
if vram == "" {
|
||||||
|
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||||
|
}
|
||||||
|
max, err := strconv.ParseUint(vram, 10, 64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
const MB = uint64(1024 * 1024)
|
||||||
|
type model struct {
|
||||||
|
name string
|
||||||
|
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
|
||||||
|
}
|
||||||
|
|
||||||
|
smallModels := []model{
|
||||||
|
{
|
||||||
|
name: "orca-mini",
|
||||||
|
size: 2992 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "phi",
|
||||||
|
size: 2616 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemma:2b",
|
||||||
|
size: 2364 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stable-code:3b",
|
||||||
|
size: 2608 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "starcoder2:3b",
|
||||||
|
size: 2166 * MB,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mediumModels := []model{
|
||||||
|
{
|
||||||
|
name: "llama2",
|
||||||
|
size: 5118 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mistral",
|
||||||
|
size: 4620 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "orca-mini:7b",
|
||||||
|
size: 5118 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dolphin-mistral",
|
||||||
|
size: 4620 * MB,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gemma:7b",
|
||||||
|
size: 5000 * MB,
|
||||||
|
},
|
||||||
|
// TODO - uncomment this once #3565 is merged and this is rebased on it
|
||||||
|
// {
|
||||||
|
// name: "codellama:7b",
|
||||||
|
// size: 5118 * MB,
|
||||||
|
// },
|
||||||
|
}
|
||||||
|
|
||||||
|
// These seem to be too slow to be useful...
|
||||||
|
// largeModels := []model{
|
||||||
|
// {
|
||||||
|
// name: "llama2:13b",
|
||||||
|
// size: 7400 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "codellama:13b",
|
||||||
|
// size: 7400 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "orca-mini:13b",
|
||||||
|
// size: 7400 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "gemma:7b",
|
||||||
|
// size: 5000 * MB,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "starcoder2:15b",
|
||||||
|
// size: 9100 * MB,
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
|
||||||
|
var chosenModels []model
|
||||||
|
switch {
|
||||||
|
case max < 10000*MB:
|
||||||
|
slog.Info("selecting small models")
|
||||||
|
chosenModels = smallModels
|
||||||
|
// case max < 30000*MB:
|
||||||
|
default:
|
||||||
|
slog.Info("selecting medium models")
|
||||||
|
chosenModels = mediumModels
|
||||||
|
// default:
|
||||||
|
// slog.Info("selecting large models")
|
||||||
|
// chosenModels = largModels
|
||||||
|
}
|
||||||
|
|
||||||
|
req, resp := GenerateRequests()
|
||||||
|
|
||||||
|
for i := range req {
|
||||||
|
if i > len(chosenModels) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
req[i].Model = chosenModels[i].name
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Make sure all the models are pulled before we get started
|
||||||
|
for _, r := range req {
|
||||||
|
require.NoError(t, PullIfMissing(ctx, client, r.Model))
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
consumed := uint64(256 * MB) // Assume some baseline usage
|
||||||
|
for i := 0; i < len(req); i++ {
|
||||||
|
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
|
||||||
|
if i > 1 && consumed > max {
|
||||||
|
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
consumed += chosenModels[i].size
|
||||||
|
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 3; j++ {
|
||||||
|
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
|
||||||
|
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
|
@ -4,7 +4,6 @@ package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
|
||||||
"num_ctx": 128,
|
"num_ctx": 128,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
|
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ package integration
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := "the ollamas"
|
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||||
|
resp := "the ollam"
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
|
GenerateTestHelper(ctx, t, req, []string{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||||
|
|
|
@ -4,8 +4,6 @@ package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -45,25 +43,5 @@ var (
|
||||||
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
func TestIntegrationSimpleOrcaMini(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
|
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
|
||||||
// The server always loads a new runner and closes the old one, which forces serial execution
|
|
||||||
// At present this test case fails with concurrency problems. Eventually we should try to
|
|
||||||
// get true concurrency working with n_parallel support in the backend
|
|
||||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(len(req))
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
|
||||||
defer cancel()
|
|
||||||
for i := 0; i < len(req); i++ {
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO - create a parallel test with 2 different models once we support concurrency
|
|
||||||
|
|
117
integration/max_queue_test.go
Normal file
117
integration/max_queue_test.go
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMaxQueue(t *testing.T) {
|
||||||
|
// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
|
||||||
|
// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
|
||||||
|
threadCount := 32
|
||||||
|
mq := os.Getenv("OLLAMA_MAX_QUEUE")
|
||||||
|
if mq != "" {
|
||||||
|
var err error
|
||||||
|
threadCount, err = strconv.Atoi(mq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
|
||||||
|
}
|
||||||
|
|
||||||
|
req := api.GenerateRequest{
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp := []string{"explore", "discover", "ocean"}
|
||||||
|
|
||||||
|
// CPU mode takes much longer at the limit with a large queue setting
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||||
|
|
||||||
|
// Context for the worker threads so we can shut them down
|
||||||
|
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||||
|
embedCtx := ctx
|
||||||
|
|
||||||
|
var genwg sync.WaitGroup
|
||||||
|
go func() {
|
||||||
|
genwg.Add(1)
|
||||||
|
defer genwg.Done()
|
||||||
|
slog.Info("Starting generate request")
|
||||||
|
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
|
||||||
|
slog.Info("generate completed")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give the generate a chance to get started before we start hammering on embed requests
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
|
||||||
|
busyCount := 0
|
||||||
|
resetByPeerCount := 0
|
||||||
|
canceledCount := 0
|
||||||
|
succesCount := 0
|
||||||
|
counterMu := sync.Mutex{}
|
||||||
|
var embedwg sync.WaitGroup
|
||||||
|
for i := 0; i < threadCount; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
embedwg.Add(1)
|
||||||
|
defer embedwg.Done()
|
||||||
|
slog.Info("embed started", "id", i)
|
||||||
|
embedReq := api.EmbeddingRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
Options: req.Options,
|
||||||
|
}
|
||||||
|
// Fresh client for every request
|
||||||
|
client, _ = GetTestEndpoint()
|
||||||
|
|
||||||
|
resp, genErr := client.Embeddings(embedCtx, &embedReq)
|
||||||
|
counterMu.Lock()
|
||||||
|
defer counterMu.Unlock()
|
||||||
|
switch {
|
||||||
|
case genErr == nil:
|
||||||
|
succesCount++
|
||||||
|
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
||||||
|
case errors.Is(genErr, context.Canceled):
|
||||||
|
canceledCount++
|
||||||
|
case strings.Contains(genErr.Error(), "busy"):
|
||||||
|
busyCount++
|
||||||
|
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||||
|
resetByPeerCount++
|
||||||
|
default:
|
||||||
|
require.NoError(t, genErr, "%d request failed", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("embed finished", "id", i)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
genwg.Wait()
|
||||||
|
slog.Info("generate done, waiting for embeds")
|
||||||
|
embedwg.Wait()
|
||||||
|
|
||||||
|
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
||||||
|
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
||||||
|
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
||||||
|
|
||||||
|
slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||||
|
}
|
|
@ -5,13 +5,14 @@ package integration
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -23,9 +24,13 @@ import (
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/lifecycle"
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Init sets up logging for the integration tests by calling
// lifecycle.InitLogging.
func Init() {
|
||||||
|
lifecycle.InitLogging()
|
||||||
|
}
|
||||||
|
|
||||||
func FindPort() string {
|
func FindPort() string {
|
||||||
port := 0
|
port := 0
|
||||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||||
|
@ -41,7 +46,7 @@ func FindPort() string {
|
||||||
return strconv.Itoa(port)
|
return strconv.Itoa(port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTestEndpoint() (string, string) {
|
func GetTestEndpoint() (*api.Client, string) {
|
||||||
defaultPort := "11434"
|
defaultPort := "11434"
|
||||||
ollamaHost := os.Getenv("OLLAMA_HOST")
|
ollamaHost := os.Getenv("OLLAMA_HOST")
|
||||||
|
|
||||||
|
@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
|
||||||
port = FindPort()
|
port = FindPort()
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s:%s", host, port)
|
slog.Info("server connection", "host", host, "port", port)
|
||||||
slog.Info("server connection", "url", url)
|
|
||||||
return scheme, url
|
return api.NewClient(
|
||||||
|
&url.URL{
|
||||||
|
Scheme: scheme,
|
||||||
|
Host: net.JoinHostPort(host, port),
|
||||||
|
},
|
||||||
|
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO make fanicier, grab logs, etc.
|
|
||||||
var serverMutex sync.Mutex
|
var serverMutex sync.Mutex
|
||||||
var serverReady bool
|
var serverReady bool
|
||||||
|
|
||||||
func StartServer(ctx context.Context, ollamaHost string) error {
|
func startServer(ctx context.Context, ollamaHost string) error {
|
||||||
// Make sure the server has been built
|
// Make sure the server has been built
|
||||||
CLIName, err := filepath.Abs("../ollama")
|
CLIName, err := filepath.Abs("../ollama")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -98,7 +107,7 @@ func StartServer(ctx context.Context, ollamaHost string) error {
|
||||||
|
|
||||||
if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
|
if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
|
||||||
slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
|
slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
|
||||||
os.Setenv("OLLAMA_HOST", ollamaHost)
|
t.Setenv("OLLAMA_HOST", ollamaHost)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("starting server", "url", ollamaHost)
|
slog.Info("starting server", "url", ollamaHost)
|
||||||
|
@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
|
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
|
||||||
slog.Info("checking status of model", "model", modelName)
|
slog.Info("checking status of model", "model", modelName)
|
||||||
showReq := &api.ShowRequest{Name: modelName}
|
showReq := &api.ShowRequest{Name: modelName}
|
||||||
requestJSON, err := json.Marshal(showReq)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
|
showCtx, cancel := context.WithDeadlineCause(
|
||||||
if err != nil {
|
ctx,
|
||||||
|
time.Now().Add(5*time.Second),
|
||||||
|
fmt.Errorf("show for existing model %s took too long", modelName),
|
||||||
|
)
|
||||||
|
defer cancel()
|
||||||
|
_, err := client.Show(showCtx, showReq)
|
||||||
|
var statusError api.StatusError
|
||||||
|
switch {
|
||||||
|
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||||
|
break
|
||||||
|
case err != nil:
|
||||||
return err
|
return err
|
||||||
}
|
default:
|
||||||
|
|
||||||
// Make the request with the HTTP client
|
|
||||||
response, err := client.Do(req.WithContext(ctx))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer response.Body.Close()
|
|
||||||
if response.StatusCode == 200 {
|
|
||||||
slog.Info("model already present", "model", modelName)
|
slog.Info("model already present", "model", modelName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
slog.Info("model missing", "status", response.StatusCode)
|
slog.Info("model missing", "model", modelName)
|
||||||
|
|
||||||
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
|
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
|
||||||
requestJSON, err = json.Marshal(pullReq)
|
stallTimer := time.NewTimer(stallDuration)
|
||||||
if err != nil {
|
fn := func(resp api.ProgressResponse) error {
|
||||||
return err
|
// fmt.Print(".")
|
||||||
|
if !stallTimer.Reset(stallDuration) {
|
||||||
|
return fmt.Errorf("stall was detected, aborting status reporting")
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
slog.Info("pulling", "model", modelName)
|
|
||||||
|
|
||||||
response, err = client.Do(req.WithContext(ctx))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer response.Body.Close()
|
|
||||||
if response.StatusCode != 200 {
|
|
||||||
return fmt.Errorf("failed to pull model") // TODO more details perhaps
|
|
||||||
}
|
|
||||||
slog.Info("model pulled", "model", modelName)
|
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
|
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
|
||||||
|
|
||||||
|
var pullError error
|
||||||
|
|
||||||
|
done := make(chan int)
|
||||||
|
go func() {
|
||||||
|
pullError = client.Pull(ctx, pullReq, fn)
|
||||||
|
done <- 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
|
return fmt.Errorf("download stalled")
|
||||||
|
case <-done:
|
||||||
|
return pullError
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var serverProcMutex sync.Mutex
|
var serverProcMutex sync.Mutex
|
||||||
|
|
||||||
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
|
// Returns a Client, the testEndpoint, and a cleanup function; fails the test on errors
|
||||||
|
// Starts the server if needed
|
||||||
// TODO maybe stuff in an init routine?
|
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
|
||||||
lifecycle.InitLogging()
|
client, testEndpoint := GetTestEndpoint()
|
||||||
|
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||||
requestJSON, err := json.Marshal(genReq)
|
serverProcMutex.Lock()
|
||||||
|
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error serializing request: %v", err)
|
t.Fatalf("failed to generate log file: %s", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
lifecycle.ServerLogFile = fp.Name()
|
||||||
|
fp.Close()
|
||||||
|
require.NoError(t, startServer(ctx, testEndpoint))
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, testEndpoint, func() {
|
||||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||||
defer serverProcMutex.Unlock()
|
defer serverProcMutex.Unlock()
|
||||||
if t.Failed() {
|
if t.Failed() {
|
||||||
|
@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
|
||||||
os.Stderr.Write(data)
|
os.Stderr.Write(data)
|
||||||
slog.Warn("END OF SERVER")
|
slog.Warn("END OF SERVER")
|
||||||
}
|
}
|
||||||
err = os.Remove(lifecycle.ServerLogFile)
|
err := os.Remove(lifecycle.ServerLogFile)
|
||||||
if err != nil && !os.IsNotExist(err) {
|
if err != nil && !os.IsNotExist(err) {
|
||||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||||
|
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
|
||||||
|
stallTimer := time.NewTimer(initialTimeout)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fn := func(response api.GenerateResponse) error {
|
||||||
|
// fmt.Print(".")
|
||||||
|
buf.Write([]byte(response.Response))
|
||||||
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
|
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
|
genReq.Stream = &stream
|
||||||
|
done := make(chan int)
|
||||||
|
var genErr error
|
||||||
|
go func() {
|
||||||
|
genErr = client.Generate(ctx, &genReq, fn)
|
||||||
|
done <- 0
|
||||||
}()
|
}()
|
||||||
scheme, testEndpoint := GetTestEndpoint()
|
|
||||||
|
|
||||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
select {
|
||||||
serverProcMutex.Lock()
|
case <-stallTimer.C:
|
||||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
if buf.Len() == 0 {
|
||||||
if err != nil {
|
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||||
t.Fatalf("failed to generate log file: %s", err)
|
} else {
|
||||||
|
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||||
}
|
}
|
||||||
lifecycle.ServerLogFile = fp.Name()
|
case <-done:
|
||||||
fp.Close()
|
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||||
assert.NoError(t, StartServer(ctx, testEndpoint))
|
|
||||||
}
|
|
||||||
|
|
||||||
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error pulling model: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make the request and get the response
|
|
||||||
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error creating request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the content type for the request
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
// Make the request with the HTTP client
|
|
||||||
response, err := client.Do(req.WithContext(ctx))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error making request: %v", err)
|
|
||||||
}
|
|
||||||
defer response.Body.Close()
|
|
||||||
body, err := io.ReadAll(response.Body)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, response.StatusCode, 200, string(body))
|
|
||||||
|
|
||||||
// Verify the response is valid JSON
|
|
||||||
var payload api.GenerateResponse
|
|
||||||
err = json.Unmarshal(body, &payload)
|
|
||||||
if err != nil {
|
|
||||||
assert.NoError(t, err, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the response contains the expected data
|
// Verify the response contains the expected data
|
||||||
|
response := buf.String()
|
||||||
atLeastOne := false
|
atLeastOne := false
|
||||||
for _, resp := range anyResp {
|
for _, resp := range anyResp {
|
||||||
if strings.Contains(strings.ToLower(payload.Response), resp) {
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
atLeastOne = true
|
atLeastOne = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
|
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
|
||||||
|
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Error("outer test context done while waiting for generate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a set of requests
|
||||||
|
// By default each request uses orca-mini as the model
|
||||||
|
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||||
|
return []api.GenerateRequest{
|
||||||
|
{
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the ocean blue?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "why is the color of dirt brown?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "what is the origin of independence day?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Model: "orca-mini",
|
||||||
|
Prompt: "what is the composition of air?",
|
||||||
|
Stream: &stream,
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
[][]string{
|
||||||
|
[]string{"sunlight"},
|
||||||
|
[]string{"soil", "organic", "earth", "black", "tan"},
|
||||||
|
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||||
|
[]string{"fourth", "july", "declaration", "independence"},
|
||||||
|
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
17
llm/ext_server/server.cpp
vendored
17
llm/ext_server/server.cpp
vendored
|
@ -1032,7 +1032,7 @@ struct llama_server_context
|
||||||
slot.has_next_token = false;
|
slot.has_next_token = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
|
if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
|
||||||
{
|
{
|
||||||
slot.stopped_eos = true;
|
slot.stopped_eos = true;
|
||||||
slot.has_next_token = false;
|
slot.has_next_token = false;
|
||||||
|
@ -1144,12 +1144,15 @@ struct llama_server_context
|
||||||
|
|
||||||
res.result_json = json
|
res.result_json = json
|
||||||
{
|
{
|
||||||
{"content", tkn.text_to_send},
|
|
||||||
{"stop", false},
|
{"stop", false},
|
||||||
{"slot_id", slot.id},
|
{"slot_id", slot.id},
|
||||||
{"multimodal", multimodal}
|
{"multimodal", multimodal}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (!llama_token_is_eog(model, tkn.tok)) {
|
||||||
|
res.result_json["content"] = tkn.text_to_send;
|
||||||
|
}
|
||||||
|
|
||||||
if (slot.sparams.n_probs > 0)
|
if (slot.sparams.n_probs > 0)
|
||||||
{
|
{
|
||||||
std::vector<completion_token_output> probs_output = {};
|
std::vector<completion_token_output> probs_output = {};
|
||||||
|
@ -1183,8 +1186,6 @@ struct llama_server_context
|
||||||
{"model", params.model_alias},
|
{"model", params.model_alias},
|
||||||
{"tokens_predicted", slot.n_decoded},
|
{"tokens_predicted", slot.n_decoded},
|
||||||
{"tokens_evaluated", slot.n_prompt_tokens},
|
{"tokens_evaluated", slot.n_prompt_tokens},
|
||||||
{"generation_settings", get_formated_generation(slot)},
|
|
||||||
{"prompt", slot.prompt},
|
|
||||||
{"truncated", slot.truncated},
|
{"truncated", slot.truncated},
|
||||||
{"stopped_eos", slot.stopped_eos},
|
{"stopped_eos", slot.stopped_eos},
|
||||||
{"stopped_word", slot.stopped_word},
|
{"stopped_word", slot.stopped_word},
|
||||||
|
@ -2644,18 +2645,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
||||||
if (strncmp(sep, "int:", 4) == 0) {
|
if (strncmp(sep, "int:", 4) == 0) {
|
||||||
sep += 4;
|
sep += 4;
|
||||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
|
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
|
||||||
kvo.int_value = std::atol(sep);
|
kvo.val_i64 = std::atol(sep);
|
||||||
} else if (strncmp(sep, "float:", 6) == 0) {
|
} else if (strncmp(sep, "float:", 6) == 0) {
|
||||||
sep += 6;
|
sep += 6;
|
||||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
|
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
|
||||||
kvo.float_value = std::atof(sep);
|
kvo.val_f64 = std::atof(sep);
|
||||||
} else if (strncmp(sep, "bool:", 5) == 0) {
|
} else if (strncmp(sep, "bool:", 5) == 0) {
|
||||||
sep += 5;
|
sep += 5;
|
||||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
|
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
|
||||||
if (std::strcmp(sep, "true") == 0) {
|
if (std::strcmp(sep, "true") == 0) {
|
||||||
kvo.bool_value = true;
|
kvo.val_bool = true;
|
||||||
} else if (std::strcmp(sep, "false") == 0) {
|
} else if (std::strcmp(sep, "false") == 0) {
|
||||||
kvo.bool_value = false;
|
kvo.val_bool = false;
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
|
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
|
140
llm/filetype.go
Normal file
140
llm/filetype.go
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
package llm
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// fileType identifies the quantization / storage format of a model file.
// The numeric value of each constant is its position in the list below.
type fileType uint32

const (
	fileTypeF32 fileType = iota
	fileTypeF16
	fileTypeQ4_0
	fileTypeQ4_1
	fileTypeQ4_1_F16
	fileTypeQ4_2 // unused
	fileTypeQ4_3 // unused
	fileTypeQ8_0
	fileTypeQ5_0
	fileTypeQ5_1
	fileTypeQ2_K
	fileTypeQ3_K_S
	fileTypeQ3_K_M
	fileTypeQ3_K_L
	fileTypeQ4_K_S
	fileTypeQ4_K_M
	fileTypeQ5_K_S
	fileTypeQ5_K_M
	fileTypeQ6_K
	fileTypeIQ2_XXS
	fileTypeIQ2_XS
	fileTypeQ2_K_S
	fileTypeQ3_K_XS
	fileTypeIQ3_XXS

	fileTypeUnknown
)

// fileTypeNames maps each named fileType to its textual form. Entries left
// empty (the unused Q4_2 and Q4_3 slots) have no name: they neither parse nor
// print as anything other than "unknown".
var fileTypeNames = [...]string{
	fileTypeF32:      "F32",
	fileTypeF16:      "F16",
	fileTypeQ4_0:     "Q4_0",
	fileTypeQ4_1:     "Q4_1",
	fileTypeQ4_1_F16: "Q4_1_F16",
	fileTypeQ8_0:     "Q8_0",
	fileTypeQ5_0:     "Q5_0",
	fileTypeQ5_1:     "Q5_1",
	fileTypeQ2_K:     "Q2_K",
	fileTypeQ3_K_S:   "Q3_K_S",
	fileTypeQ3_K_M:   "Q3_K_M",
	fileTypeQ3_K_L:   "Q3_K_L",
	fileTypeQ4_K_S:   "Q4_K_S",
	fileTypeQ4_K_M:   "Q4_K_M",
	fileTypeQ5_K_S:   "Q5_K_S",
	fileTypeQ5_K_M:   "Q5_K_M",
	fileTypeQ6_K:     "Q6_K",
	fileTypeIQ2_XXS:  "IQ2_XXS",
	fileTypeIQ2_XS:   "IQ2_XS",
	fileTypeQ2_K_S:   "Q2_K_S",
	fileTypeQ3_K_XS:  "Q3_K_XS",
	fileTypeIQ3_XXS:  "IQ3_XXS",
}

// ParseFileType converts a textual file type such as "Q4_K_M" into its
// fileType value. The match is exact and case-sensitive. Any other input
// yields fileTypeUnknown and a non-nil error.
func ParseFileType(s string) (fileType, error) {
	// The empty string must not match the unnamed gaps in the table.
	if s != "" {
		for idx, name := range fileTypeNames {
			if name == s {
				return fileType(idx), nil
			}
		}
	}
	return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
}

// String returns the textual name of t, or "unknown" when t has no name.
func (t fileType) String() string {
	if t < fileType(len(fileTypeNames)) {
		if name := fileTypeNames[t]; name != "" {
			return name
		}
	}
	return "unknown"
}

// Value returns the numeric value of t as a uint32.
func (t fileType) Value() uint32 {
	return uint32(t)
}
|
|
@ -57,11 +57,10 @@ init_vars
|
||||||
git_module_setup
|
git_module_setup
|
||||||
apply_patches
|
apply_patches
|
||||||
|
|
||||||
|
|
||||||
init_vars
|
init_vars
|
||||||
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
|
||||||
|
# Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static"
|
||||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
|
# Enables optimized Dockerfile builds using a blanket skip and targeted overrides
|
||||||
# Static build for linking into the Go binary
|
# Static build for linking into the Go binary
|
||||||
init_vars
|
init_vars
|
||||||
CMAKE_TARGETS="--target llama --target ggml"
|
CMAKE_TARGETS="--target llama --target ggml"
|
||||||
|
@ -69,9 +68,10 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||||
BUILD_DIR="../build/linux/${ARCH}_static"
|
BUILD_DIR="../build/linux/${ARCH}_static"
|
||||||
echo "Building static library"
|
echo "Building static library"
|
||||||
build
|
build
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
init_vars
|
||||||
|
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||||
# Users building from source can tune the exact flags we pass to cmake for configuring
|
# Users building from source can tune the exact flags we pass to cmake for configuring
|
||||||
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
|
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
|
||||||
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
|
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
|
||||||
|
@ -172,7 +172,15 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||||
# Disabling has minimal performance effect while maintaining compatibility.
|
# Disabling has minimal performance effect while maintaining compatibility.
|
||||||
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
|
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
|
||||||
fi
|
fi
|
||||||
CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
|
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
|
||||||
|
if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
|
||||||
|
echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
|
||||||
|
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
|
||||||
|
echo "Building custom CUDA GPU"
|
||||||
|
else
|
||||||
|
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
|
||||||
|
fi
|
||||||
|
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
|
||||||
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||||
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
|
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
|
||||||
build
|
build
|
||||||
|
@ -217,6 +225,12 @@ if [ -d "${ROCM_PATH}" ]; then
|
||||||
fi
|
fi
|
||||||
init_vars
|
init_vars
|
||||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||||
|
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
|
||||||
|
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
|
||||||
|
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
|
||||||
|
CMAKE_DEFS="${CMAKE_DEFS} ${OLLAMA_CUSTOM_ROCM_DEFS}"
|
||||||
|
echo "Building custom ROCM GPU"
|
||||||
|
fi
|
||||||
BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
|
BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
|
||||||
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
|
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
|
||||||
build
|
build
|
||||||
|
|
|
@ -26,15 +26,25 @@ function amdGPUs {
|
||||||
$GPU_LIST -join ';'
|
$GPU_LIST -join ';'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
function init_vars {
|
function init_vars {
|
||||||
|
if (!$script:SRC_DIR) {
|
||||||
$script:SRC_DIR = $(resolve-path "..\..\")
|
$script:SRC_DIR = $(resolve-path "..\..\")
|
||||||
|
}
|
||||||
|
if (!$script:llamacppDir) {
|
||||||
$script:llamacppDir = "../llama.cpp"
|
$script:llamacppDir = "../llama.cpp"
|
||||||
|
}
|
||||||
|
if (!$script:cmakeTargets) {
|
||||||
|
$script:cmakeTargets = @("ollama_llama_server")
|
||||||
|
}
|
||||||
$script:cmakeDefs = @(
|
$script:cmakeDefs = @(
|
||||||
"-DBUILD_SHARED_LIBS=on",
|
"-DBUILD_SHARED_LIBS=on",
|
||||||
"-DLLAMA_NATIVE=off"
|
"-DLLAMA_NATIVE=off"
|
||||||
)
|
)
|
||||||
$script:cmakeTargets = @("ollama_llama_server")
|
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
||||||
$script:ARCH = "amd64" # arm not yet supported.
|
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
|
||||||
|
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
|
||||||
|
md "$script:DIST_BASE" -ea 0 > $null
|
||||||
if ($env:CGO_CFLAGS -contains "-g") {
|
if ($env:CGO_CFLAGS -contains "-g") {
|
||||||
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
|
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
|
||||||
$script:config = "RelWithDebInfo"
|
$script:config = "RelWithDebInfo"
|
||||||
|
@ -55,7 +65,6 @@ function init_vars {
|
||||||
} else {
|
} else {
|
||||||
$script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
|
$script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
|
||||||
}
|
}
|
||||||
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
|
|
||||||
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
|
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
|
||||||
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
|
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
|
||||||
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
||||||
|
@ -134,21 +143,18 @@ function sign {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function compress {
|
function install {
|
||||||
if ($script:GZIP -eq $null) {
|
write-host "Installing binaries to dist dir ${script:distDir}"
|
||||||
write-host "gzip not installed, not compressing files"
|
mkdir ${script:distDir} -ErrorAction SilentlyContinue
|
||||||
return
|
|
||||||
}
|
|
||||||
write-host "Compressing binaries..."
|
|
||||||
$binaries = dir "${script:buildDir}/bin/*.exe"
|
$binaries = dir "${script:buildDir}/bin/*.exe"
|
||||||
foreach ($file in $binaries) {
|
foreach ($file in $binaries) {
|
||||||
& "$script:GZIP" --best -f $file
|
copy-item -Path $file -Destination ${script:distDir} -Force
|
||||||
}
|
}
|
||||||
|
|
||||||
write-host "Compressing dlls..."
|
write-host "Installing dlls to dist dir ${script:distDir}"
|
||||||
$dlls = dir "${script:buildDir}/bin/*.dll"
|
$dlls = dir "${script:buildDir}/bin/*.dll"
|
||||||
foreach ($file in $dlls) {
|
foreach ($file in $dlls) {
|
||||||
& "$script:GZIP" --best -f $file
|
copy-item -Path $file -Destination ${script:distDir} -Force
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,28 +175,25 @@ function cleanup {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
init_vars
|
|
||||||
git_module_setup
|
|
||||||
apply_patches
|
|
||||||
|
|
||||||
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
||||||
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
|
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
|
||||||
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
|
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
|
||||||
|
|
||||||
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
|
||||||
|
|
||||||
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
|
function build_static() {
|
||||||
|
if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
|
||||||
# GCC build for direct linking into the Go binary
|
# GCC build for direct linking into the Go binary
|
||||||
init_vars
|
init_vars
|
||||||
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
|
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
|
||||||
# as we need this to be compiled by gcc for golang to be able to link with itx
|
# as we need this to be compiled by gcc for golang to be able to link with itx
|
||||||
write-host "Checking for MinGW..."
|
write-host "Checking for MinGW..."
|
||||||
# error action ensures we exit on failure
|
# error action ensures we exit on failure
|
||||||
get-command gcc
|
get-command gcc
|
||||||
get-command mingw32-make
|
get-command mingw32-make
|
||||||
$script:cmakeTargets = @("llama", "ggml")
|
$oldTargets = $script:cmakeTargets
|
||||||
$script:cmakeDefs = @(
|
$script:cmakeTargets = @("llama", "ggml")
|
||||||
|
$script:cmakeDefs = @(
|
||||||
"-G", "MinGW Makefiles"
|
"-G", "MinGW Makefiles"
|
||||||
"-DCMAKE_C_COMPILER=gcc.exe",
|
"-DCMAKE_C_COMPILER=gcc.exe",
|
||||||
"-DCMAKE_CXX_COMPILER=g++.exe",
|
"-DCMAKE_CXX_COMPILER=g++.exe",
|
||||||
|
@ -201,39 +204,63 @@ $script:cmakeDefs = @(
|
||||||
"-DLLAMA_AVX512=off",
|
"-DLLAMA_AVX512=off",
|
||||||
"-DLLAMA_F16C=off",
|
"-DLLAMA_F16C=off",
|
||||||
"-DLLAMA_FMA=off")
|
"-DLLAMA_FMA=off")
|
||||||
$script:buildDir="../build/windows/${script:ARCH}_static"
|
$script:buildDir="../build/windows/${script:ARCH}_static"
|
||||||
write-host "Building static library"
|
write-host "Building static library"
|
||||||
build
|
build
|
||||||
|
$script:cmakeTargets = $oldTargets
|
||||||
|
} else {
|
||||||
|
write-host "Skipping CPU generation step as requested"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# remaining llama.cpp builds use MSVC
|
function build_cpu($gen_arch) {
|
||||||
|
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
||||||
|
# remaining llama.cpp builds use MSVC
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
||||||
|
$script:distDir="$script:DIST_BASE\cpu"
|
||||||
write-host "Building LCD CPU"
|
write-host "Building LCD CPU"
|
||||||
build
|
build
|
||||||
sign
|
sign
|
||||||
compress
|
install
|
||||||
|
} else {
|
||||||
|
write-host "Skipping CPU generation step as requested"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function build_cpu_avx() {
|
||||||
|
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
|
||||||
|
$script:distDir="$script:DIST_BASE\cpu_avx"
|
||||||
write-host "Building AVX CPU"
|
write-host "Building AVX CPU"
|
||||||
build
|
build
|
||||||
sign
|
sign
|
||||||
compress
|
install
|
||||||
|
} else {
|
||||||
|
write-host "Skipping CPU AVX generation step as requested"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function build_cpu_avx2() {
|
||||||
|
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
|
||||||
|
$script:distDir="$script:DIST_BASE\cpu_avx2"
|
||||||
write-host "Building AVX2 CPU"
|
write-host "Building AVX2 CPU"
|
||||||
build
|
build
|
||||||
sign
|
sign
|
||||||
compress
|
install
|
||||||
} else {
|
} else {
|
||||||
write-host "Skipping CPU generation step as requested"
|
write-host "Skipping CPU AVX2 generation step as requested"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($null -ne $script:CUDA_LIB_DIR) {
|
function build_cuda() {
|
||||||
|
if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
|
||||||
# Then build cuda as a dynamically loaded library
|
# Then build cuda as a dynamically loaded library
|
||||||
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
|
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
|
||||||
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
|
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
|
||||||
|
@ -242,13 +269,28 @@ if ($null -ne $script:CUDA_LIB_DIR) {
|
||||||
}
|
}
|
||||||
init_vars
|
init_vars
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
||||||
|
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
|
||||||
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
|
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
|
||||||
|
if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) {
|
||||||
|
write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`""
|
||||||
|
$script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}")
|
||||||
|
write-host "building custom CUDA GPU"
|
||||||
|
}
|
||||||
build
|
build
|
||||||
sign
|
sign
|
||||||
compress
|
install
|
||||||
|
|
||||||
|
write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
||||||
|
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
||||||
|
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
||||||
|
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
||||||
|
} else {
|
||||||
|
write-host "Skipping CUDA generation step"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ($null -ne $env:HIP_PATH) {
|
function build_rocm() {
|
||||||
|
if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
|
||||||
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
|
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
|
||||||
if ($null -ne $script:ROCM_VERSION) {
|
if ($null -ne $script:ROCM_VERSION) {
|
||||||
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
|
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
|
||||||
|
@ -256,6 +298,7 @@ if ($null -ne $env:HIP_PATH) {
|
||||||
|
|
||||||
init_vars
|
init_vars
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
|
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
|
||||||
|
$script:distDir="$script:DIST_BASE\rocm$script:ROCM_VARIANT"
|
||||||
$script:cmakeDefs += @(
|
$script:cmakeDefs += @(
|
||||||
"-G", "Ninja",
|
"-G", "Ninja",
|
||||||
"-DCMAKE_C_COMPILER=clang.exe",
|
"-DCMAKE_C_COMPILER=clang.exe",
|
||||||
|
@ -274,7 +317,11 @@ if ($null -ne $env:HIP_PATH) {
|
||||||
|
|
||||||
# We have to clobber the LIB var from the developer shell for clang to work properly
|
# We have to clobber the LIB var from the developer shell for clang to work properly
|
||||||
$env:LIB=""
|
$env:LIB=""
|
||||||
|
if ($null -ne $env:OLLAMA_CUSTOM_ROCM_DEFS) {
|
||||||
|
write-host "OLLAMA_CUSTOM_ROCM_DEFS=`"${env:OLLAMA_CUSTOM_ROCM_DEFS}`""
|
||||||
|
$script:cmakeDefs += @("${env:OLLAMA_CUSTOM_ROCM_DEFS}")
|
||||||
|
write-host "building custom ROCM GPU"
|
||||||
|
}
|
||||||
write-host "Building ROCm"
|
write-host "Building ROCm"
|
||||||
build
|
build
|
||||||
# Ninja doesn't prefix with config name
|
# Ninja doesn't prefix with config name
|
||||||
|
@ -283,9 +330,40 @@ if ($null -ne $env:HIP_PATH) {
|
||||||
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
|
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
|
||||||
}
|
}
|
||||||
sign
|
sign
|
||||||
compress
|
install
|
||||||
|
|
||||||
|
# Assumes v5.7, may need adjustments for v6
|
||||||
|
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
|
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
|
||||||
|
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
|
cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
|
# amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
|
||||||
|
cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
|
||||||
|
} else {
|
||||||
|
write-host "Skipping ROCm generation step"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
init_vars
|
||||||
|
if ($($args.count) -eq 0) {
|
||||||
|
git_module_setup
|
||||||
|
apply_patches
|
||||||
|
build_static
|
||||||
|
if ($script:ARCH -eq "arm64") {
|
||||||
|
build_cpu("ARM64")
|
||||||
|
} else { # amd64
|
||||||
|
build_cpu("x64")
|
||||||
|
build_cpu_avx
|
||||||
|
build_cpu_avx2
|
||||||
|
build_cuda
|
||||||
|
build_rocm
|
||||||
|
}
|
||||||
|
|
||||||
cleanup
|
cleanup
|
||||||
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"
|
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"
|
||||||
|
} else {
|
||||||
|
for ( $i = 0; $i -lt $args.count; $i++ ) {
|
||||||
|
write-host "performing $($args[$i])"
|
||||||
|
& $($args[$i])
|
||||||
|
}
|
||||||
|
}
|
107
llm/ggml.go
107
llm/ggml.go
|
@ -13,82 +13,6 @@ type GGML struct {
|
||||||
model
|
model
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
fileTypeF32 uint32 = iota
|
|
||||||
fileTypeF16
|
|
||||||
fileTypeQ4_0
|
|
||||||
fileTypeQ4_1
|
|
||||||
fileTypeQ4_1_F16
|
|
||||||
fileTypeQ8_0 uint32 = iota + 2
|
|
||||||
fileTypeQ5_0
|
|
||||||
fileTypeQ5_1
|
|
||||||
fileTypeQ2_K
|
|
||||||
fileTypeQ3_K_S
|
|
||||||
fileTypeQ3_K_M
|
|
||||||
fileTypeQ3_K_L
|
|
||||||
fileTypeQ4_K_S
|
|
||||||
fileTypeQ4_K_M
|
|
||||||
fileTypeQ5_K_S
|
|
||||||
fileTypeQ5_K_M
|
|
||||||
fileTypeQ6_K
|
|
||||||
fileTypeIQ2_XXS
|
|
||||||
fileTypeIQ2_XS
|
|
||||||
fileTypeQ2_K_S
|
|
||||||
fileTypeQ3_K_XS
|
|
||||||
fileTypeIQ3_XXS
|
|
||||||
)
|
|
||||||
|
|
||||||
func fileType(fileType uint32) string {
|
|
||||||
switch fileType {
|
|
||||||
case fileTypeF32:
|
|
||||||
return "F32"
|
|
||||||
case fileTypeF16:
|
|
||||||
return "F16"
|
|
||||||
case fileTypeQ4_0:
|
|
||||||
return "Q4_0"
|
|
||||||
case fileTypeQ4_1:
|
|
||||||
return "Q4_1"
|
|
||||||
case fileTypeQ4_1_F16:
|
|
||||||
return "Q4_1_F16"
|
|
||||||
case fileTypeQ8_0:
|
|
||||||
return "Q8_0"
|
|
||||||
case fileTypeQ5_0:
|
|
||||||
return "Q5_0"
|
|
||||||
case fileTypeQ5_1:
|
|
||||||
return "Q5_1"
|
|
||||||
case fileTypeQ2_K:
|
|
||||||
return "Q2_K"
|
|
||||||
case fileTypeQ3_K_S:
|
|
||||||
return "Q3_K_S"
|
|
||||||
case fileTypeQ3_K_M:
|
|
||||||
return "Q3_K_M"
|
|
||||||
case fileTypeQ3_K_L:
|
|
||||||
return "Q3_K_L"
|
|
||||||
case fileTypeQ4_K_S:
|
|
||||||
return "Q4_K_S"
|
|
||||||
case fileTypeQ4_K_M:
|
|
||||||
return "Q4_K_M"
|
|
||||||
case fileTypeQ5_K_S:
|
|
||||||
return "Q5_K_S"
|
|
||||||
case fileTypeQ5_K_M:
|
|
||||||
return "Q5_K_M"
|
|
||||||
case fileTypeQ6_K:
|
|
||||||
return "Q6_K"
|
|
||||||
case fileTypeIQ2_XXS:
|
|
||||||
return "IQ2_XXS"
|
|
||||||
case fileTypeIQ2_XS:
|
|
||||||
return "IQ2_XS"
|
|
||||||
case fileTypeQ2_K_S:
|
|
||||||
return "Q2_K_S"
|
|
||||||
case fileTypeQ3_K_XS:
|
|
||||||
return "Q3_K_XS"
|
|
||||||
case fileTypeIQ3_XXS:
|
|
||||||
return "IQ3_XXS"
|
|
||||||
default:
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type model interface {
|
type model interface {
|
||||||
KV() KV
|
KV() KV
|
||||||
Tensors() Tensors
|
Tensors() Tensors
|
||||||
|
@ -121,12 +45,12 @@ func (kv KV) ParameterCount() uint64 {
|
||||||
return kv.u64("general.parameter_count")
|
return kv.u64("general.parameter_count")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) FileType() string {
|
func (kv KV) FileType() fileType {
|
||||||
if u64 := kv.u64("general.file_type"); u64 > 0 {
|
if u64 := kv.u64("general.file_type"); u64 > 0 {
|
||||||
return fileType(uint32(u64))
|
return fileType(uint32(u64))
|
||||||
}
|
}
|
||||||
|
|
||||||
return "unknown"
|
return fileTypeUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) BlockCount() uint64 {
|
func (kv KV) BlockCount() uint64 {
|
||||||
|
@ -286,6 +210,23 @@ const (
|
||||||
|
|
||||||
var ErrUnsupportedFormat = errors.New("unsupported model format")
|
var ErrUnsupportedFormat = errors.New("unsupported model format")
|
||||||
|
|
||||||
|
func DetectGGMLType(b []byte) string {
|
||||||
|
switch binary.LittleEndian.Uint32(b[:4]) {
|
||||||
|
case FILE_MAGIC_GGML:
|
||||||
|
return "ggml"
|
||||||
|
case FILE_MAGIC_GGMF:
|
||||||
|
return "ggmf"
|
||||||
|
case FILE_MAGIC_GGJT:
|
||||||
|
return "ggjt"
|
||||||
|
case FILE_MAGIC_GGLA:
|
||||||
|
return "ggla"
|
||||||
|
case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
|
||||||
|
return "gguf"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
|
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
|
||||||
var magic uint32
|
var magic uint32
|
||||||
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
||||||
|
@ -343,7 +284,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
|
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
|
||||||
|
// mixtral 8x22b
|
||||||
|
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
|
||||||
|
partialOffload = max(
|
||||||
|
3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
|
||||||
|
4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
|
||||||
|
)
|
||||||
|
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
|
||||||
|
// mixtral 8x7b
|
||||||
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
|
|
16
llm/gguf.go
16
llm/gguf.go
|
@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
|
||||||
llm.kv[k] = v
|
llm.kv[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
|
|
||||||
|
|
||||||
// decode tensors
|
// decode tensors
|
||||||
for i := 0; uint64(i) < llm.numTensor(); i++ {
|
for i := 0; uint64(i) < llm.numTensor(); i++ {
|
||||||
name, err := readGGUFString(llm, rs)
|
name, err := readGGUFString(llm, rs)
|
||||||
|
@ -465,11 +463,13 @@ var ggufKVOrder = map[string][]string{
|
||||||
"llama.embedding_length",
|
"llama.embedding_length",
|
||||||
"llama.block_count",
|
"llama.block_count",
|
||||||
"llama.feed_forward_length",
|
"llama.feed_forward_length",
|
||||||
"llama.rope.dimension_count",
|
|
||||||
"llama.attention.head_count",
|
"llama.attention.head_count",
|
||||||
"llama.attention.head_count_kv",
|
"llama.attention.head_count_kv",
|
||||||
"llama.attention.layer_norm_rms_epsilon",
|
"llama.attention.layer_norm_rms_epsilon",
|
||||||
"llama.rope.freq_base",
|
"llama.rope.freq_base",
|
||||||
|
"llama.rope.dimension_count",
|
||||||
|
"llama.expert_count",
|
||||||
|
"llama.expert_used_count",
|
||||||
"gemma.context_length",
|
"gemma.context_length",
|
||||||
"gemma.embedding_length",
|
"gemma.embedding_length",
|
||||||
"gemma.block_count",
|
"gemma.block_count",
|
||||||
|
@ -577,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("improper type for '%s'", k)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -598,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dims := 1
|
dims := 0
|
||||||
if tensor.Shape[1] > 0 {
|
for cnt := 0; cnt < len(tensor.Shape); cnt++ {
|
||||||
dims = 2
|
if tensor.Shape[cnt] > 0 {
|
||||||
|
dims++
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
|
if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 7593639ce335e8d7f89aa9a54d616951f273af60
|
Subproject commit 952d03dbead16e4dbdd1d3458486340673cc2465
|
57
llm/llm.go
57
llm/llm.go
|
@ -4,6 +4,7 @@ package llm
|
||||||
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
|
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
|
||||||
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
|
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
|
||||||
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
|
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
|
||||||
|
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
|
||||||
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
|
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
|
||||||
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
|
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
|
||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
|
@ -19,7 +20,7 @@ func SystemInfo() string {
|
||||||
return C.GoString(C.llama_print_system_info())
|
return C.GoString(C.llama_print_system_info())
|
||||||
}
|
}
|
||||||
|
|
||||||
func Quantize(infile, outfile, filetype string) error {
|
func Quantize(infile, outfile string, ftype fileType) error {
|
||||||
cinfile := C.CString(infile)
|
cinfile := C.CString(infile)
|
||||||
defer C.free(unsafe.Pointer(cinfile))
|
defer C.free(unsafe.Pointer(cinfile))
|
||||||
|
|
||||||
|
@ -28,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
|
||||||
|
|
||||||
params := C.llama_model_quantize_default_params()
|
params := C.llama_model_quantize_default_params()
|
||||||
params.nthread = -1
|
params.nthread = -1
|
||||||
|
params.ftype = ftype.Value()
|
||||||
|
|
||||||
switch filetype {
|
if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 {
|
||||||
case "F32":
|
return fmt.Errorf("llama_model_quantize: %d", rc)
|
||||||
params.ftype = fileTypeF32
|
|
||||||
case "F16":
|
|
||||||
params.ftype = fileTypeF16
|
|
||||||
case "Q4_0":
|
|
||||||
params.ftype = fileTypeQ4_0
|
|
||||||
case "Q4_1":
|
|
||||||
params.ftype = fileTypeQ4_1
|
|
||||||
case "Q4_1_F16":
|
|
||||||
params.ftype = fileTypeQ4_1_F16
|
|
||||||
case "Q8_0":
|
|
||||||
params.ftype = fileTypeQ8_0
|
|
||||||
case "Q5_0":
|
|
||||||
params.ftype = fileTypeQ5_0
|
|
||||||
case "Q5_1":
|
|
||||||
params.ftype = fileTypeQ5_1
|
|
||||||
case "Q2_K":
|
|
||||||
params.ftype = fileTypeQ2_K
|
|
||||||
case "Q3_K_S":
|
|
||||||
params.ftype = fileTypeQ3_K_S
|
|
||||||
case "Q3_K_M":
|
|
||||||
params.ftype = fileTypeQ3_K_M
|
|
||||||
case "Q3_K_L":
|
|
||||||
params.ftype = fileTypeQ3_K_L
|
|
||||||
case "Q4_K_S":
|
|
||||||
params.ftype = fileTypeQ4_K_S
|
|
||||||
case "Q4_K_M":
|
|
||||||
params.ftype = fileTypeQ4_K_M
|
|
||||||
case "Q5_K_S":
|
|
||||||
params.ftype = fileTypeQ5_K_S
|
|
||||||
case "Q5_K_M":
|
|
||||||
params.ftype = fileTypeQ5_K_M
|
|
||||||
case "Q6_K":
|
|
||||||
params.ftype = fileTypeQ6_K
|
|
||||||
case "IQ2_XXS":
|
|
||||||
params.ftype = fileTypeIQ2_XXS
|
|
||||||
case "IQ2_XS":
|
|
||||||
params.ftype = fileTypeIQ2_XS
|
|
||||||
case "Q2_K_S":
|
|
||||||
params.ftype = fileTypeQ2_K_S
|
|
||||||
case "Q3_K_XS":
|
|
||||||
params.ftype = fileTypeQ3_K_XS
|
|
||||||
case "IQ3_XXS":
|
|
||||||
params.ftype = fileTypeIQ3_XXS
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown filetype: %s", filetype)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retval := C.llama_model_quantize(cinfile, coutfile, ¶ms); retval != 0 {
|
|
||||||
return fmt.Errorf("llama_model_quantize: %d", retval)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -2,5 +2,5 @@ package llm
|
||||||
|
|
||||||
import "embed"
|
import "embed"
|
||||||
|
|
||||||
//go:embed build/windows/*/*/bin/*
|
// unused on windows
|
||||||
var libEmbed embed.FS
|
var libEmbed embed.FS
|
||||||
|
|
185
llm/memory.go
Normal file
185
llm/memory.go
Normal file
|
@ -0,0 +1,185 @@
|
||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/gpu"
|
||||||
|
"github.com/ollama/ollama/server/envconfig"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||||
|
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
|
||||||
|
var estimatedVRAM uint64
|
||||||
|
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||||
|
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||||
|
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.NumCtx < 4 {
|
||||||
|
opts.NumCtx = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split up the GPUs by type and try them
|
||||||
|
for _, gpus := range allGpus.ByLibrary() {
|
||||||
|
var layerCount int
|
||||||
|
layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||||
|
if opts.NumGPU < 0 {
|
||||||
|
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
|
||||||
|
return true, estimatedVRAM
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if layerCount > 0 && layerCount >= opts.NumGPU {
|
||||||
|
return true, estimatedVRAM
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, estimatedVRAM
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||||
|
// The GPUs provided must all be the same Library
|
||||||
|
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
|
||||||
|
var memoryAvailable uint64
|
||||||
|
for _, info := range gpus {
|
||||||
|
memoryAvailable += info.FreeMemory
|
||||||
|
}
|
||||||
|
if envconfig.MaxVRAM > 0 {
|
||||||
|
memoryAvailable = envconfig.MaxVRAM
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
|
||||||
|
|
||||||
|
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
|
||||||
|
memoryMinimum := gpus[0].MinimumMemory
|
||||||
|
|
||||||
|
for _, projector := range projectors {
|
||||||
|
memoryMinimum += projectorMemoryRequirements(projector)
|
||||||
|
|
||||||
|
// multimodal models require at least 2048 context
|
||||||
|
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
||||||
|
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
|
||||||
|
|
||||||
|
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
||||||
|
if graphPartialOffload == 0 {
|
||||||
|
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
||||||
|
}
|
||||||
|
|
||||||
|
if graphFullOffload == 0 {
|
||||||
|
graphFullOffload = graphPartialOffload
|
||||||
|
}
|
||||||
|
|
||||||
|
graphFullOffload *= uint64(len(gpus))
|
||||||
|
graphPartialOffload *= uint64(len(gpus))
|
||||||
|
|
||||||
|
// on metal there's no partial offload overhead
|
||||||
|
if gpus[0].Library == "metal" {
|
||||||
|
graphPartialOffload = graphFullOffload
|
||||||
|
}
|
||||||
|
|
||||||
|
layers := ggml.Tensors().Layers()
|
||||||
|
|
||||||
|
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
||||||
|
memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size()
|
||||||
|
|
||||||
|
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
|
||||||
|
memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size()
|
||||||
|
|
||||||
|
var memoryLayerOutput uint64
|
||||||
|
if layer, ok := layers["output_norm"]; ok {
|
||||||
|
memoryLayerOutput += layer.size()
|
||||||
|
}
|
||||||
|
|
||||||
|
if layer, ok := layers["output"]; ok {
|
||||||
|
memoryLayerOutput += layer.size()
|
||||||
|
} else if layer, ok := layers["token_embd"]; ok {
|
||||||
|
memoryLayerOutput += layer.size()
|
||||||
|
}
|
||||||
|
|
||||||
|
if gpus[0].Library == "metal" && opts.UseMMap {
|
||||||
|
// memory is preallocated for output tensors
|
||||||
|
memoryRequiredTotal += memoryLayerOutput
|
||||||
|
memoryRequiredPartial += memoryLayerOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
var layerCount int
|
||||||
|
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
||||||
|
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
|
||||||
|
|
||||||
|
// KV is proportional to the number of layers
|
||||||
|
memoryLayer += kv / ggml.KV().BlockCount()
|
||||||
|
|
||||||
|
memoryRequiredTotal += memoryLayer
|
||||||
|
if memoryAvailable > memoryRequiredPartial+memoryLayer {
|
||||||
|
memoryRequiredPartial += memoryLayer
|
||||||
|
layerCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gpus[0].Library != "metal" || !opts.UseMMap {
|
||||||
|
// memory was not preallocated for output tensors
|
||||||
|
memoryRequiredTotal += memoryLayerOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
if memoryAvailable > memoryRequiredTotal {
|
||||||
|
layerCount = int(ggml.KV().BlockCount()) + 1
|
||||||
|
memoryRequiredPartial = memoryRequiredTotal
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
|
||||||
|
|
||||||
|
slog.Info(
|
||||||
|
"offload to gpu",
|
||||||
|
slog.Group(
|
||||||
|
"layers",
|
||||||
|
// actual number of layers offloaded
|
||||||
|
"real", opts.NumGPU,
|
||||||
|
// estimated number of layers that can be offloaded
|
||||||
|
"estimate", layerCount,
|
||||||
|
),
|
||||||
|
slog.Group(
|
||||||
|
"memory",
|
||||||
|
// memory available for offloading
|
||||||
|
"available", format.HumanBytes2(memoryAvailable),
|
||||||
|
slog.Group(
|
||||||
|
"required",
|
||||||
|
// memory required for full offloading
|
||||||
|
"full", format.HumanBytes2(memoryRequiredTotal),
|
||||||
|
// memory required to offload layers.estimate layers
|
||||||
|
"partial", format.HumanBytes2(memoryRequiredPartial),
|
||||||
|
// memory of KV cache
|
||||||
|
"kv", format.HumanBytes2(kv),
|
||||||
|
),
|
||||||
|
slog.Group(
|
||||||
|
"weights",
|
||||||
|
// memory of the weights
|
||||||
|
"total", format.HumanBytes2(memoryWeights),
|
||||||
|
// memory of repeating layers
|
||||||
|
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
|
||||||
|
// memory of non-repeating layers
|
||||||
|
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
|
||||||
|
),
|
||||||
|
slog.Group(
|
||||||
|
"graph",
|
||||||
|
// memory of graph when fully offloaded
|
||||||
|
"full", format.HumanBytes2(graphFullOffload),
|
||||||
|
// memory of graph when not fully offloaded
|
||||||
|
"partial", format.HumanBytes2(graphPartialOffload),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if gpus[0].Library == "cpu" {
|
||||||
|
return 0, 0, memoryRequiredTotal
|
||||||
|
}
|
||||||
|
if memoryRequiredPartial > memoryAvailable {
|
||||||
|
slog.Debug("insufficient VRAM to load any model layers")
|
||||||
|
return 0, 0, memoryRequiredTotal
|
||||||
|
}
|
||||||
|
|
||||||
|
return layerCount, memoryRequiredPartial, memoryRequiredTotal
|
||||||
|
}
|
12
llm/patches/02-clip-log.diff
Normal file
12
llm/patches/02-clip-log.diff
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
|
||||||
|
index e431c7f7..f077e688 100644
|
||||||
|
--- a/examples/llava/clip.cpp
|
||||||
|
+++ b/examples/llava/clip.cpp
|
||||||
|
@@ -3,6 +3,7 @@
|
||||||
|
// I'll gradually clean and extend it
|
||||||
|
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
|
||||||
|
#include "clip.h"
|
||||||
|
+#include "common.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-alloc.h"
|
45
llm/patches/04-metal.diff
Normal file
45
llm/patches/04-metal.diff
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
diff --git a/ggml-metal.m b/ggml-metal.m
|
||||||
|
index 0207b787..b5e9884b 100644
|
||||||
|
--- a/ggml-metal.m
|
||||||
|
+++ b/ggml-metal.m
|
||||||
|
@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
// to the matrix-vector kernel
|
||||||
|
int ne11_mm_min = 1;
|
||||||
|
|
||||||
|
-#if 0
|
||||||
|
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
||||||
|
// these numbers do not translate to other devices or model sizes
|
||||||
|
// TODO: need to find a better approach
|
||||||
|
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
|
||||||
|
- switch (src0t) {
|
||||||
|
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
||||||
|
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
||||||
|
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
||||||
|
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
||||||
|
- case GGML_TYPE_Q4_0:
|
||||||
|
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
||||||
|
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
||||||
|
- case GGML_TYPE_Q5_0: // not tested yet
|
||||||
|
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
||||||
|
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
||||||
|
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
||||||
|
- default: ne11_mm_min = 1; break;
|
||||||
|
- }
|
||||||
|
+ switch (src0t) {
|
||||||
|
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
||||||
|
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
||||||
|
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
||||||
|
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
||||||
|
+ case GGML_TYPE_Q4_0:
|
||||||
|
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
||||||
|
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
||||||
|
+ case GGML_TYPE_Q5_0: // not tested yet
|
||||||
|
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
||||||
|
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
||||||
|
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
||||||
|
+ default: ne11_mm_min = 1; break;
|
||||||
|
}
|
||||||
|
-#endif
|
||||||
|
|
||||||
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
24
llm/patches/05-clip-fix.diff
Normal file
24
llm/patches/05-clip-fix.diff
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
|
||||||
|
index e3c9bcd4..b43f892d 100644
|
||||||
|
--- a/examples/llava/clip.cpp
|
||||||
|
+++ b/examples/llava/clip.cpp
|
||||||
|
@@ -573,14 +573,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
|
struct ggml_tensor * embeddings = inp;
|
||||||
|
if (ctx->has_class_embedding) {
|
||||||
|
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
|
||||||
|
+ }
|
||||||
|
+ ggml_set_name(embeddings, "embeddings");
|
||||||
|
+ ggml_set_input(embeddings);
|
||||||
|
+
|
||||||
|
+ if (ctx->has_class_embedding) {
|
||||||
|
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
|
||||||
|
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
|
||||||
|
embeddings = ggml_acc(ctx0, embeddings, inp,
|
||||||
|
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
|
||||||
|
}
|
||||||
|
- ggml_set_name(embeddings, "embeddings");
|
||||||
|
- ggml_set_input(embeddings);
|
||||||
|
-
|
||||||
|
|
||||||
|
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||||
|
ggml_set_name(positions, "positions");
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
@ -17,7 +18,7 @@ import (
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
|
var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
|
||||||
|
|
||||||
func Init() error {
|
func Init() error {
|
||||||
payloadsDir, err := gpu.PayloadsDir()
|
payloadsDir, err := gpu.PayloadsDir()
|
||||||
|
@ -25,6 +26,7 @@ func Init() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
slog.Info("extracting embedded files", "dir", payloadsDir)
|
slog.Info("extracting embedded files", "dir", payloadsDir)
|
||||||
binGlob := "build/*/*/*/bin/*"
|
binGlob := "build/*/*/*/bin/*"
|
||||||
|
|
||||||
|
@ -33,6 +35,7 @@ func Init() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("extract binaries: %v", err)
|
return fmt.Errorf("extract binaries: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var variants []string
|
var variants []string
|
||||||
for v := range availableServers() {
|
for v := range availableServers() {
|
||||||
|
@ -138,6 +141,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
|
||||||
return servers
|
return servers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the optimal server for this CPU architecture
|
||||||
|
func serverForCpu() string {
|
||||||
|
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||||
|
return "metal"
|
||||||
|
}
|
||||||
|
variant := gpu.GetCPUVariant()
|
||||||
|
availableServers := availableServers()
|
||||||
|
if variant != "" {
|
||||||
|
for cmp := range availableServers {
|
||||||
|
if cmp == "cpu_"+variant {
|
||||||
|
return cmp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "cpu"
|
||||||
|
}
|
||||||
|
|
||||||
// extract extracts the embedded files to the target directory
|
// extract extracts the embedded files to the target directory
|
||||||
func extractFiles(targetDir string, glob string) error {
|
func extractFiles(targetDir string, glob string) error {
|
||||||
files, err := fs.Glob(libEmbed, glob)
|
files, err := fs.Glob(libEmbed, glob)
|
||||||
|
|
454
llm/server.go
454
llm/server.go
|
@ -21,21 +21,47 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
|
"github.com/ollama/ollama/server/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LlamaServer is an instance of the llama.cpp server
|
type LlamaServer interface {
|
||||||
type LlamaServer struct {
|
Ping(ctx context.Context) error
|
||||||
|
WaitUntilRunning(ctx context.Context) error
|
||||||
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
|
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
||||||
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
|
Close() error
|
||||||
|
EstimatedVRAM() uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// llmServer is an instance of the llama.cpp server
|
||||||
|
type llmServer struct {
|
||||||
port int
|
port int
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
done chan error // Channel to signal when the process exits
|
done chan error // Channel to signal when the process exits
|
||||||
status *StatusWriter
|
status *StatusWriter
|
||||||
options api.Options
|
options api.Options
|
||||||
|
|
||||||
|
// TODO - this should be broken down by GPU
|
||||||
|
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
|
||||||
|
estimatedTotal uint64 // Total size of model
|
||||||
|
totalLayers uint64
|
||||||
|
gpuCount int
|
||||||
|
|
||||||
|
sem *semaphore.Weighted
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
|
func LoadModel(model string) (*GGML, error) {
|
||||||
|
if _, err := os.Stat(model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
f, err := os.Open(model)
|
f, err := os.Open(model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -43,144 +69,69 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
ggml, _, err := DecodeGGML(f)
|
ggml, _, err := DecodeGGML(f)
|
||||||
if err != nil {
|
return ggml, err
|
||||||
return nil, err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// NewLlamaServer will run a server for the given GPUs
|
||||||
|
// The gpu list must be a single family.
|
||||||
|
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
|
||||||
|
var err error
|
||||||
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||||
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
|
||||||
opts.NumCtx = int(ggml.KV().ContextLength())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.NumCtx < 4 {
|
if opts.NumCtx < 4 {
|
||||||
opts.NumCtx = 4
|
opts.NumCtx = 4
|
||||||
}
|
}
|
||||||
|
|
||||||
memoryAvailable, _ := gpu.CheckVRAM()
|
cpuRunner := ""
|
||||||
info := gpu.GetGPUInfo()
|
var estimatedVRAM uint64
|
||||||
|
var estimatedTotal uint64
|
||||||
|
var systemMemory uint64
|
||||||
|
gpuCount := len(gpus)
|
||||||
|
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
|
||||||
|
|
||||||
memoryMinimum := info.MinimumMemory
|
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
|
||||||
for _, projector := range projectors {
|
|
||||||
memoryMinimum += projectorMemoryRequirements(projector)
|
|
||||||
|
|
||||||
// multimodal models require at least 2048 context
|
cpuRunner = serverForCpu()
|
||||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
gpuCount = 0
|
||||||
}
|
} else {
|
||||||
|
if gpus[0].Library == "metal" {
|
||||||
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
memInfo, err := gpu.GetCPUMem()
|
||||||
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
|
if err != nil {
|
||||||
|
slog.Error("failed to lookup system memory", "error", err)
|
||||||
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
} else {
|
||||||
if graphPartialOffload == 0 {
|
systemMemory = memInfo.TotalMemory
|
||||||
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
|
||||||
}
|
|
||||||
|
|
||||||
if graphFullOffload == 0 {
|
|
||||||
graphFullOffload = graphPartialOffload
|
|
||||||
}
|
|
||||||
|
|
||||||
graphFullOffload *= uint64(info.DeviceCount)
|
|
||||||
graphPartialOffload *= uint64(info.DeviceCount)
|
|
||||||
|
|
||||||
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
|
||||||
memoryRequiredTotal := memoryMinimum + graphFullOffload
|
|
||||||
|
|
||||||
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
|
|
||||||
memoryRequiredPartial := memoryMinimum + graphPartialOffload
|
|
||||||
|
|
||||||
if info.Library != "metal" {
|
|
||||||
if memoryRequiredPartial > memoryAvailable {
|
|
||||||
info.Library = "cpu"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var layers int
|
||||||
|
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||||
|
|
||||||
var layerCount int
|
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
|
||||||
layers := ggml.Tensors().Layers()
|
// disable partial offloading when model is greater than total system memory as this
|
||||||
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
// can lead to locking up the system
|
||||||
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
|
|
||||||
|
|
||||||
// KV is proportional to the number of layers
|
|
||||||
memoryLayer += kv / ggml.KV().BlockCount()
|
|
||||||
|
|
||||||
memoryRequiredTotal += memoryLayer
|
|
||||||
if memoryAvailable > memoryRequiredPartial+memoryLayer {
|
|
||||||
memoryRequiredPartial += memoryLayer
|
|
||||||
layerCount++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var memoryLayerOutput uint64
|
|
||||||
for k, v := range layers {
|
|
||||||
if !strings.HasPrefix(k, "blk.") {
|
|
||||||
memoryLayerOutput += v.size()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
memoryRequiredTotal += memoryLayerOutput
|
|
||||||
|
|
||||||
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
|
|
||||||
// disable partial offloading when model is greater than total system memory
|
|
||||||
opts.NumGPU = 0
|
opts.NumGPU = 0
|
||||||
} else if memoryAvailable > memoryRequiredTotal {
|
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
|
||||||
layerCount = int(ggml.KV().BlockCount()) + 1
|
opts.NumGPU = layers
|
||||||
memoryRequiredPartial = memoryRequiredTotal
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.NumGPU < 0 {
|
// Loop through potential servers
|
||||||
opts.NumGPU = layerCount
|
finalErr := fmt.Errorf("no suitable llama servers found")
|
||||||
}
|
|
||||||
|
|
||||||
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
|
|
||||||
|
|
||||||
slog.Info(
|
|
||||||
"offload to gpu",
|
|
||||||
slog.Group(
|
|
||||||
"layers",
|
|
||||||
// actual number of layers offloaded
|
|
||||||
"real", opts.NumGPU,
|
|
||||||
// estimated number of layers that can be offloaded
|
|
||||||
"estimate", layerCount,
|
|
||||||
),
|
|
||||||
slog.Group(
|
|
||||||
"memory",
|
|
||||||
// memory available for offloading
|
|
||||||
"available", format.HumanBytes2(memoryAvailable),
|
|
||||||
slog.Group(
|
|
||||||
"required",
|
|
||||||
// memory required for full offloading
|
|
||||||
"full", format.HumanBytes2(memoryRequiredTotal),
|
|
||||||
// memory required to offload layers.estimate layers
|
|
||||||
"partial", format.HumanBytes2(memoryRequiredPartial),
|
|
||||||
// memory of KV cache
|
|
||||||
"kv", format.HumanBytes2(kv),
|
|
||||||
),
|
|
||||||
slog.Group(
|
|
||||||
"weights",
|
|
||||||
// memory of the weights
|
|
||||||
"total", format.HumanBytes2(memoryWeights),
|
|
||||||
// memory of repeating layers
|
|
||||||
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
|
|
||||||
// memory of non-repeating layers
|
|
||||||
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
|
|
||||||
),
|
|
||||||
slog.Group(
|
|
||||||
"graph",
|
|
||||||
// memory of graph when fully offloaded
|
|
||||||
"full", format.HumanBytes2(graphFullOffload),
|
|
||||||
// memory of graph when not fully offloaded
|
|
||||||
"partial", format.HumanBytes2(graphPartialOffload),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(adapters) > 1 {
|
if len(adapters) > 1 {
|
||||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
availableServers := availableServers()
|
availableServers := availableServers()
|
||||||
servers := serversForGpu(info)
|
var servers []string
|
||||||
|
if cpuRunner != "" {
|
||||||
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
|
servers = []string{cpuRunner}
|
||||||
|
} else {
|
||||||
|
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
|
||||||
|
}
|
||||||
|
demandLib := envconfig.LLMLibrary
|
||||||
if demandLib != "" {
|
if demandLib != "" {
|
||||||
serverPath := availableServers[demandLib]
|
serverPath := availableServers[demandLib]
|
||||||
if serverPath == "" {
|
if serverPath == "" {
|
||||||
|
@ -188,11 +139,15 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
} else {
|
} else {
|
||||||
slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
|
slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
|
||||||
servers = []string{demandLib}
|
servers = []string{demandLib}
|
||||||
|
if strings.HasPrefix(demandLib, "cpu") {
|
||||||
|
// Omit the GPU flag to silence the warning
|
||||||
|
opts.NumGPU = -1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(servers) == 0 {
|
if len(servers) == 0 {
|
||||||
return nil, fmt.Errorf("no servers found for %v", info)
|
return nil, fmt.Errorf("no servers found for %v", gpus)
|
||||||
}
|
}
|
||||||
|
|
||||||
params := []string{
|
params := []string{
|
||||||
|
@ -201,7 +156,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
|
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
|
||||||
"--embedding",
|
"--embedding",
|
||||||
}
|
}
|
||||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
if envconfig.Debug {
|
||||||
params = append(params, "--log-format", "json")
|
params = append(params, "--log-format", "json")
|
||||||
} else {
|
} else {
|
||||||
params = append(params, "--log-disable")
|
params = append(params, "--log-disable")
|
||||||
|
@ -211,7 +166,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
|
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
|
||||||
}
|
}
|
||||||
|
|
||||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
if envconfig.Debug {
|
||||||
params = append(params, "--verbose")
|
params = append(params, "--verbose")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -249,10 +204,30 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
params = append(params, "--numa")
|
params = append(params, "--numa")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Loop through potential servers
|
numParallel := envconfig.NumParallel
|
||||||
var finalErr error
|
|
||||||
|
// TODO (jmorganca): multimodal models don't support parallel yet
|
||||||
|
// see https://github.com/ollama/ollama/issues/4165
|
||||||
|
if len(projectors) > 0 {
|
||||||
|
numParallel = 1
|
||||||
|
slog.Warn("multimodal models don't support parallel requests yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
||||||
|
|
||||||
for i := 0; i < len(servers); i++ {
|
for i := 0; i < len(servers); i++ {
|
||||||
dir := availableServers[servers[i]]
|
dir := availableServers[servers[i]]
|
||||||
|
if dir == "" {
|
||||||
|
// Shouldn't happen
|
||||||
|
finalErr = fmt.Errorf("[%d] server %s not listed in available servers %v", i, servers[i], availableServers)
|
||||||
|
slog.Error("sever list inconsistent", "error", finalErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(servers[i], "cpu") {
|
||||||
|
// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
|
||||||
|
gpuCount = 0
|
||||||
|
}
|
||||||
|
|
||||||
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
|
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
|
||||||
port := 0
|
port := 0
|
||||||
|
@ -273,12 +248,21 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
pathEnv = "PATH"
|
pathEnv = "PATH"
|
||||||
}
|
}
|
||||||
// append the server directory to LD_LIBRARY_PATH/PATH
|
// prepend the server directory to LD_LIBRARY_PATH/PATH
|
||||||
libraryPaths := []string{dir}
|
libraryPaths := []string{dir}
|
||||||
|
|
||||||
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
||||||
// Append our runner directory to the path
|
// Append our runner directory to the path
|
||||||
// This will favor system libraries over our bundled library dependencies
|
// This will favor system libraries over our bundled library dependencies
|
||||||
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
|
libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: we always put the dependency path first
|
||||||
|
// since this was the exact version we verified for AMD GPUs
|
||||||
|
// and we favor what the user had in their path
|
||||||
|
if gpus[0].DependencyPath != "" {
|
||||||
|
// TODO refine for multi-gpu support
|
||||||
|
libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
|
||||||
}
|
}
|
||||||
|
|
||||||
server := filepath.Join(dir, "ollama_llama_server")
|
server := filepath.Join(dir, "ollama_llama_server")
|
||||||
|
@ -286,21 +270,66 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
server = server + ".exe"
|
server = server + ".exe"
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &LlamaServer{
|
// Detect tmp cleaners wiping out the file
|
||||||
|
_, err := os.Stat(server)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err)
|
||||||
|
err = Init()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to reinitialize payloads", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &llmServer{
|
||||||
port: port,
|
port: port,
|
||||||
cmd: exec.Command(server, finalParams...),
|
cmd: exec.Command(server, finalParams...),
|
||||||
status: NewStatusWriter(os.Stderr),
|
status: NewStatusWriter(os.Stderr),
|
||||||
options: opts,
|
options: opts,
|
||||||
|
estimatedVRAM: estimatedVRAM,
|
||||||
|
estimatedTotal: estimatedTotal,
|
||||||
|
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||||
|
totalLayers: ggml.KV().BlockCount() + 1,
|
||||||
|
gpuCount: gpuCount,
|
||||||
}
|
}
|
||||||
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
|
|
||||||
slog.Debug(libEnv)
|
s.cmd.Env = os.Environ()
|
||||||
s.cmd.Env = append(os.Environ(), libEnv)
|
|
||||||
s.cmd.Stdout = os.Stdout
|
s.cmd.Stdout = os.Stdout
|
||||||
s.cmd.Stderr = s.status
|
s.cmd.Stderr = s.status
|
||||||
|
|
||||||
|
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
|
||||||
|
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||||
|
|
||||||
|
// Update or add the path and visible devices variable with our adjusted version
|
||||||
|
pathNeeded := true
|
||||||
|
devicesNeeded := visibleDevicesEnv != ""
|
||||||
|
for i := range s.cmd.Env {
|
||||||
|
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
|
||||||
|
if strings.EqualFold(cmp[0], pathEnv) {
|
||||||
|
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||||
|
pathNeeded = false
|
||||||
|
} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
|
||||||
|
s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
|
||||||
|
devicesNeeded = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pathNeeded {
|
||||||
|
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
|
||||||
|
}
|
||||||
|
if devicesNeeded {
|
||||||
|
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
|
||||||
|
}
|
||||||
|
|
||||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||||
|
// Log at debug as the environment is inherited and might contain sensitive information
|
||||||
|
slog.Debug("subprocess", "environment", s.cmd.Env)
|
||||||
|
|
||||||
if err = s.cmd.Start(); err != nil {
|
if err = s.cmd.Start(); err != nil {
|
||||||
|
// Detect permission denied and augment them essage about noexec
|
||||||
|
if errors.Is(err, os.ErrPermission) {
|
||||||
|
finalErr = fmt.Errorf("unable to start server %w. %s may have noexec set. Set OLLAMA_TMPDIR for server to a writable executable directory", err, dir)
|
||||||
|
continue
|
||||||
|
}
|
||||||
msg := ""
|
msg := ""
|
||||||
if s.status != nil && s.status.LastErrMsg != "" {
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
msg = s.status.LastErrMsg
|
msg = s.status.LastErrMsg
|
||||||
|
@ -310,12 +339,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// reap subprocess when it exits
|
|
||||||
go func() {
|
|
||||||
// Exit status managed via getServerStatus
|
|
||||||
_ = s.cmd.Wait()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -347,12 +370,27 @@ type ServerStatus int
|
||||||
|
|
||||||
const ( // iota is reset to 0
|
const ( // iota is reset to 0
|
||||||
ServerStatusReady ServerStatus = iota
|
ServerStatusReady ServerStatus = iota
|
||||||
ServerStatusNoSlotsAvaialble
|
ServerStatusNoSlotsAvailable
|
||||||
ServerStatusLoadingModel
|
ServerStatusLoadingModel
|
||||||
ServerStatusNotResponding
|
ServerStatusNotResponding
|
||||||
ServerStatusError
|
ServerStatusError
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s ServerStatus) ToString() string {
|
||||||
|
switch s {
|
||||||
|
case ServerStatusReady:
|
||||||
|
return "llm server ready"
|
||||||
|
case ServerStatusNoSlotsAvailable:
|
||||||
|
return "llm busy - no slots available"
|
||||||
|
case ServerStatusLoadingModel:
|
||||||
|
return "llm server loading model"
|
||||||
|
case ServerStatusNotResponding:
|
||||||
|
return "llm server not responding"
|
||||||
|
default:
|
||||||
|
return "llm server error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type ServerStatusResp struct {
|
type ServerStatusResp struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
SlotsIdle int `json:"slots_idle"`
|
SlotsIdle int `json:"slots_idle"`
|
||||||
|
@ -360,13 +398,17 @@ type ServerStatusResp struct {
|
||||||
Error string `json:"error"`
|
Error string `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||||
// Fail fast if its exited
|
// Fail fast if its exited
|
||||||
if s.cmd.ProcessState != nil {
|
if s.cmd.ProcessState != nil {
|
||||||
msg := ""
|
msg := ""
|
||||||
if s.status != nil && s.status.LastErrMsg != "" {
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
msg = s.status.LastErrMsg
|
msg = s.status.LastErrMsg
|
||||||
}
|
}
|
||||||
|
if s.cmd.ProcessState.ExitCode() == -1 {
|
||||||
|
// Most likely a signal killed it, log some more details to try to help troubleshoot
|
||||||
|
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
|
||||||
|
}
|
||||||
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -399,7 +441,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
|
||||||
case "ok":
|
case "ok":
|
||||||
return ServerStatusReady, nil
|
return ServerStatusReady, nil
|
||||||
case "no slot available":
|
case "no slot available":
|
||||||
return ServerStatusNoSlotsAvaialble, nil
|
return ServerStatusNoSlotsAvailable, nil
|
||||||
case "loading model":
|
case "loading model":
|
||||||
return ServerStatusLoadingModel, nil
|
return ServerStatusLoadingModel, nil
|
||||||
default:
|
default:
|
||||||
|
@ -407,7 +449,30 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Ping(ctx context.Context) error {
|
// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
|
||||||
|
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
|
||||||
|
var retries int
|
||||||
|
for {
|
||||||
|
status, err := s.getServerStatus(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return status, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if status == ServerStatusNoSlotsAvailable {
|
||||||
|
if retries >= 10 {
|
||||||
|
return status, fmt.Errorf("no slots available after %d retries", retries)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
retries++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *llmServer) Ping(ctx context.Context) error {
|
||||||
_, err := s.getServerStatus(ctx)
|
_, err := s.getServerStatus(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug("server unhealthy", "error", err)
|
slog.Debug("server unhealthy", "error", err)
|
||||||
|
@ -416,13 +481,25 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) WaitUntilRunning() error {
|
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
|
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
|
||||||
|
|
||||||
slog.Info("waiting for llama runner to start responding")
|
slog.Info("waiting for llama runner to start responding")
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
slog.Info("context expired before server started")
|
||||||
|
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
|
||||||
|
case err := <-s.done:
|
||||||
|
msg := ""
|
||||||
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
|
msg = s.status.LastErrMsg
|
||||||
|
}
|
||||||
|
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||||
|
default:
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatus(ctx)
|
||||||
|
@ -487,7 +564,6 @@ ws ::= ([ \t\n] ws)?
|
||||||
`
|
`
|
||||||
|
|
||||||
const maxBufferSize = 512 * format.KiloByte
|
const maxBufferSize = 512 * format.KiloByte
|
||||||
const maxRetries = 3
|
|
||||||
|
|
||||||
type ImageData struct {
|
type ImageData struct {
|
||||||
Data []byte `json:"data"`
|
Data []byte `json:"data"`
|
||||||
|
@ -524,7 +600,19 @@ type CompletionResponse struct {
|
||||||
EvalDuration time.Duration
|
EvalDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer s.sem.Release(1)
|
||||||
|
|
||||||
|
// only allow maximum 10 "context shifts" to avoid infinite generation
|
||||||
|
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
|
||||||
|
req.Options.NumPredict = 10 * s.options.NumCtx
|
||||||
|
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
|
||||||
|
}
|
||||||
|
|
||||||
request := map[string]any{
|
request := map[string]any{
|
||||||
"prompt": req.Prompt,
|
"prompt": req.Prompt,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
|
@ -551,11 +639,11 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatusRetry(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return fmt.Errorf("unexpected server status: %d", status)
|
return fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Format == "json" {
|
if req.Format == "json" {
|
||||||
|
@ -565,13 +653,6 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryDelay := 100 * time.Microsecond
|
|
||||||
for retries := 0; retries < maxRetries; retries++ {
|
|
||||||
if retries > 0 {
|
|
||||||
time.Sleep(retryDelay) // wait before retrying
|
|
||||||
retryDelay *= 2 // exponential backoff
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handling JSON marshaling with special characters unescaped.
|
// Handling JSON marshaling with special characters unescaped.
|
||||||
buffer := &bytes.Buffer{}
|
buffer := &bytes.Buffer{}
|
||||||
enc := json.NewEncoder(buffer)
|
enc := json.NewEncoder(buffer)
|
||||||
|
@ -582,20 +663,20 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||||
}
|
}
|
||||||
|
|
||||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating POST request: %v", err)
|
return fmt.Errorf("error creating POST request: %v", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
serverReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(serverReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("POST predict: %v", err)
|
return fmt.Errorf("POST predict: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if res.StatusCode >= 400 {
|
||||||
bodyBytes, err := io.ReadAll(resp.Body)
|
bodyBytes, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -603,11 +684,10 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||||
return fmt.Errorf("%s", bodyBytes)
|
return fmt.Errorf("%s", bodyBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(res.Body)
|
||||||
buf := make([]byte, 0, maxBufferSize)
|
buf := make([]byte, 0, maxBufferSize)
|
||||||
scanner.Buffer(buf, maxBufferSize)
|
scanner.Buffer(buf, maxBufferSize)
|
||||||
|
|
||||||
retryNeeded := false
|
|
||||||
// keep track of the last token generated, this is used to abort if the model starts looping
|
// keep track of the last token generated, this is used to abort if the model starts looping
|
||||||
var lastToken string
|
var lastToken string
|
||||||
var tokenRepeat int
|
var tokenRepeat int
|
||||||
|
@ -623,12 +703,6 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// try again on slot unavailable
|
|
||||||
if bytes.Contains(line, []byte("slot unavailable")) {
|
|
||||||
retryNeeded = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||||
|
@ -679,19 +753,13 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
|
||||||
if s.status != nil && s.status.LastErrMsg != "" {
|
if s.status != nil && s.status.LastErrMsg != "" {
|
||||||
msg = s.status.LastErrMsg
|
msg = s.status.LastErrMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
|
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("error reading llm response: %v", err)
|
return fmt.Errorf("error reading llm response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !retryNeeded {
|
return nil
|
||||||
return nil // success
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// should never reach here ideally
|
|
||||||
return fmt.Errorf("max retries exceeded")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
|
@ -702,13 +770,19 @@ type EmbeddingResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||||
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer s.sem.Release(1)
|
||||||
|
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatusRetry(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady {
|
||||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||||
|
@ -754,13 +828,13 @@ type TokenizeResponse struct {
|
||||||
Tokens []int `json:"tokens"`
|
Tokens []int `json:"tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatus(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: content})
|
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||||
|
@ -806,13 +880,13 @@ type DetokenizeResponse struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatus(ctx)
|
status, err := s.getServerStatus(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
} else if status != ServerStatusReady {
|
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||||
return "", fmt.Errorf("unexpected server status: %d", status)
|
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||||
|
@ -850,15 +924,25 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
|
||||||
return decoded.Content, nil
|
return decoded.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *LlamaServer) Close() error {
|
func (s *llmServer) Close() error {
|
||||||
if s.cmd != nil {
|
if s.cmd != nil {
|
||||||
slog.Debug("stopping llama server")
|
slog.Debug("stopping llama server")
|
||||||
return s.cmd.Process.Kill()
|
if err := s.cmd.Process.Kill(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = s.cmd.Wait()
|
||||||
|
|
||||||
|
slog.Debug("llama server stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EstimatedVRAM returns the value held in the server's estimatedVRAM field.
func (s *llmServer) EstimatedVRAM() uint64 {
|
||||||
|
return s.estimatedVRAM
|
||||||
|
}
|
||||||
|
|
||||||
func parseDurationMs(ms float64) time.Duration {
|
func parseDurationMs(ms float64) time.Duration {
|
||||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -19,7 +19,7 @@ export default function () {
|
||||||
const [step, setStep] = useState<Step>(Step.WELCOME)
|
const [step, setStep] = useState<Step>(Step.WELCOME)
|
||||||
const [commandCopied, setCommandCopied] = useState<boolean>(false)
|
const [commandCopied, setCommandCopied] = useState<boolean>(false)
|
||||||
|
|
||||||
const command = 'ollama run llama2'
|
const command = 'ollama run llama3'
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className='drag'>
|
<div className='drag'>
|
||||||
|
|
132
parser/parser.go
132
parser/parser.go
|
@ -1,132 +0,0 @@
|
||||||
package parser
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log/slog"
|
|
||||||
"slices"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Command is a single parsed directive: Name identifies the directive and
// Args carries its argument text.
type Command struct {
	Name, Args string
}

// Reset clears both fields so the value can be reused for the next directive.
func (c *Command) Reset() {
	*c = Command{}
}
|
|
||||||
|
|
||||||
func Parse(reader io.Reader) ([]Command, error) {
|
|
||||||
var commands []Command
|
|
||||||
var command, modelCommand Command
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(reader)
|
|
||||||
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
|
|
||||||
scanner.Split(scanModelfile)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
|
|
||||||
fields := bytes.SplitN(line, []byte(" "), 2)
|
|
||||||
if len(fields) == 0 || len(fields[0]) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch string(bytes.ToUpper(fields[0])) {
|
|
||||||
case "FROM":
|
|
||||||
command.Name = "model"
|
|
||||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
|
||||||
// copy command for validation
|
|
||||||
modelCommand = command
|
|
||||||
case "ADAPTER":
|
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
|
||||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
|
||||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
|
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
|
||||||
command.Args = string(fields[1])
|
|
||||||
case "PARAMETER":
|
|
||||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
|
||||||
if len(fields) < 2 {
|
|
||||||
return nil, fmt.Errorf("missing value for %s", fields)
|
|
||||||
}
|
|
||||||
|
|
||||||
command.Name = string(fields[0])
|
|
||||||
command.Args = string(bytes.TrimSpace(fields[1]))
|
|
||||||
case "EMBED":
|
|
||||||
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
|
|
||||||
case "MESSAGE":
|
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
|
||||||
fields = bytes.SplitN(fields[1], []byte(" "), 2)
|
|
||||||
if len(fields) < 2 {
|
|
||||||
return nil, fmt.Errorf("should be in the format <role> <message>")
|
|
||||||
}
|
|
||||||
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
|
|
||||||
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
|
|
||||||
}
|
|
||||||
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
|
|
||||||
default:
|
|
||||||
if !bytes.HasPrefix(fields[0], []byte("#")) {
|
|
||||||
// log a warning for unknown commands
|
|
||||||
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
commands = append(commands, command)
|
|
||||||
command.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelCommand.Args == "" {
|
|
||||||
return nil, errors.New("no FROM line for the model was specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
return commands, scanner.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
||||||
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if advance > 0 && token != nil {
|
|
||||||
return advance, token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if advance > 0 && token != nil {
|
|
||||||
return advance, token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return bufio.ScanLines(data, atEOF)
|
|
||||||
}
|
|
||||||
|
|
||||||
// scan looks for a span delimited by openBytes and closeBytes that opens on
// the first line of data.
//
// When such a span is found, token is everything before the opening delimiter
// followed by the span's contents (delimiters removed), and advance is the
// offset just past the closing delimiter. The span may extend over several
// lines. token is a freshly allocated slice, so data is never modified.
//
// (0, nil, nil) is returned when the first line holds no opening delimiter, or
// when more input is needed (no newline seen yet, or no closing delimiter yet,
// and atEOF is false). An error is returned when the span is still open at
// EOF.
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
	lineEnd := bytes.IndexByte(data, '\n')
	if lineEnd < 0 {
		if !atEOF {
			// The first line is not complete yet; wait for more data.
			return 0, nil, nil
		}
		// Final line without a trailing newline: the whole remainder is the line.
		lineEnd = len(data)
	}

	start := bytes.Index(data, openBytes)
	if start < 0 || start >= lineEnd {
		return 0, nil, nil
	}

	bodyStart := start + len(openBytes)
	end := bytes.Index(data[bodyStart:], closeBytes)
	if end < 0 {
		if atEOF {
			return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
		}
		return 0, nil, nil
	}
	bodyEnd := bodyStart + end

	// Build the token in its own buffer instead of appending into data[:start],
	// which would overwrite the caller's bytes.
	token = make([]byte, 0, start+end)
	token = append(token, data[:start]...)
	token = append(token, data[bodyStart:bodyEnd]...)

	return bodyEnd + len(closeBytes), token, nil
}
|
|
|
@ -1,98 +0,0 @@
|
||||||
package parser
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_Parser(t *testing.T) {
|
|
||||||
|
|
||||||
input := `
|
|
||||||
FROM model1
|
|
||||||
ADAPTER adapter1
|
|
||||||
LICENSE MIT
|
|
||||||
PARAMETER param1 value1
|
|
||||||
PARAMETER param2 value2
|
|
||||||
TEMPLATE template1
|
|
||||||
`
|
|
||||||
|
|
||||||
reader := strings.NewReader(input)
|
|
||||||
|
|
||||||
commands, err := Parse(reader)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
expectedCommands := []Command{
|
|
||||||
{Name: "model", Args: "model1"},
|
|
||||||
{Name: "adapter", Args: "adapter1"},
|
|
||||||
{Name: "license", Args: "MIT"},
|
|
||||||
{Name: "param1", Args: "value1"},
|
|
||||||
{Name: "param2", Args: "value2"},
|
|
||||||
{Name: "template", Args: "template1"},
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, expectedCommands, commands)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_Parser_NoFromLine(t *testing.T) {
|
|
||||||
|
|
||||||
input := `
|
|
||||||
PARAMETER param1 value1
|
|
||||||
PARAMETER param2 value2
|
|
||||||
`
|
|
||||||
|
|
||||||
reader := strings.NewReader(input)
|
|
||||||
|
|
||||||
_, err := Parse(reader)
|
|
||||||
assert.ErrorContains(t, err, "no FROM line")
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_Parser_MissingValue(t *testing.T) {
|
|
||||||
|
|
||||||
input := `
|
|
||||||
FROM foo
|
|
||||||
PARAMETER param1
|
|
||||||
`
|
|
||||||
|
|
||||||
reader := strings.NewReader(input)
|
|
||||||
|
|
||||||
_, err := Parse(reader)
|
|
||||||
assert.ErrorContains(t, err, "missing value for [param1]")
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_Parser_Messages(t *testing.T) {
|
|
||||||
|
|
||||||
input := `
|
|
||||||
FROM foo
|
|
||||||
MESSAGE system You are a Parser. Always Parse things.
|
|
||||||
MESSAGE user Hey there!
|
|
||||||
MESSAGE assistant Hello, I want to parse all the things!
|
|
||||||
`
|
|
||||||
|
|
||||||
reader := strings.NewReader(input)
|
|
||||||
commands, err := Parse(reader)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
expectedCommands := []Command{
|
|
||||||
{Name: "model", Args: "foo"},
|
|
||||||
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
|
|
||||||
{Name: "message", Args: "user: Hey there!"},
|
|
||||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, expectedCommands, commands)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_Parser_Messages_BadRole(t *testing.T) {
|
|
||||||
|
|
||||||
input := `
|
|
||||||
FROM foo
|
|
||||||
MESSAGE badguy I'm a bad guy!
|
|
||||||
`
|
|
||||||
|
|
||||||
reader := strings.NewReader(input)
|
|
||||||
_, err := Parse(reader)
|
|
||||||
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
|
|
||||||
}
|
|
|
@ -218,7 +218,7 @@ func (i *Instance) Readline() (string, error) {
|
||||||
case CharCtrlZ:
|
case CharCtrlZ:
|
||||||
fd := int(syscall.Stdin)
|
fd := int(syscall.Stdin)
|
||||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||||
case CharEnter:
|
case CharEnter, CharCtrlJ:
|
||||||
output := buf.String()
|
output := buf.String()
|
||||||
if output != "" {
|
if output != "" {
|
||||||
i.History.Add([]rune(output))
|
i.History.Add([]rune(output))
|
||||||
|
@ -232,7 +232,7 @@ func (i *Instance) Readline() (string, error) {
|
||||||
metaDel = false
|
metaDel = false
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if r >= CharSpace || r == CharEnter {
|
if r >= CharSpace || r == CharEnter || r == CharCtrlJ {
|
||||||
buf.Add(r)
|
buf.Add(r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue