Merge remote-tracking branch 'upstream/main' into pr3702

This commit is contained in:
Daniel Hiltgen 2024-05-08 16:44:35 -07:00
commit 920a4b0794
132 changed files with 7701 additions and 4766 deletions

View file

@ -103,6 +103,7 @@ jobs:
path: | path: |
llm/build/**/bin/* llm/build/**/bin/*
llm/build/**/*.a llm/build/**/*.a
dist/windows-amd64/**
# ROCm generation step # ROCm generation step
generate-windows-rocm: generate-windows-rocm:
@ -173,7 +174,9 @@ jobs:
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: generate-windows-rocm name: generate-windows-rocm
path: llm/build/**/bin/* path: |
llm/build/**/bin/*
dist/windows-amd64/**
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: windows-rocm-deps name: windows-rocm-deps
@ -253,7 +256,9 @@ jobs:
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: generate-windows-cuda name: generate-windows-cuda
path: llm/build/**/bin/* path: |
llm/build/**/bin/*
dist/windows-amd64/**
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: windows-cuda-deps name: windows-cuda-deps
@ -306,23 +311,18 @@ jobs:
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
name: generate-windows-cpu name: generate-windows-cpu
path: llm/build
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
name: generate-windows-cuda name: generate-windows-cuda
path: llm/build
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
name: windows-cuda-deps name: windows-cuda-deps
path: dist/deps
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
name: windows-rocm-deps name: windows-rocm-deps
path: dist/deps
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
name: generate-windows-rocm name: generate-windows-rocm
path: llm/build
- run: dir llm/build - run: dir llm/build
- run: | - run: |
$gopath=(get-command go).source | split-path -parent $gopath=(get-command go).source | split-path -parent
@ -331,13 +331,13 @@ jobs:
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0" $env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
$env:PATH="$gopath;$env:PATH" $env:PATH="$gopath;$env:PATH"
$env:OLLAMA_SKIP_GENERATE="1" $env:OLLAMA_SKIP_GENERATE="1"
$env:NVIDIA_DIR=$(resolve-path ".\dist\deps")
$env:HIP_PATH=$(resolve-path ".\dist\deps")
& .\scripts\build_windows.ps1 & .\scripts\build_windows.ps1
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: dist-windows name: dist-windows
path: dist/*.exe path: |
dist/OllamaSetup.exe
dist/ollama-windows-*.zip
# Linux x86 assets built using the container based build # Linux x86 assets built using the container based build
build-linux-amd64: build-linux-amd64:

View file

@ -1,5 +1,15 @@
name: test name: test
concurrency:
# For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
# cancels running CI jobs and starts all new ones.
#
# For non-PR pushes, concurrency.group needs to be unique for every distinct
# CI run we want to have happen. Use run_id, which in practice means all
# non-PR CI runs will be allowed to run without preempting each other.
group: ${{ github.workflow }}-$${{ github.pull_request.number || github.run_id }}
cancel-in-progress: true
on: on:
pull_request: pull_request:
paths: paths:
@ -21,7 +31,9 @@ jobs:
- id: changes - id: changes
run: | run: |
changed() { changed() {
git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \ git diff-tree -r --no-commit-id --name-only \
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
${{ github.event.pull_request.head.sha }} \
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))" | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
} }
@ -103,7 +115,9 @@ jobs:
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: cuda-${{ matrix.cuda-version }}-libraries name: cuda-${{ matrix.cuda-version }}-libraries
path: llm/build/**/bin/* path: |
llm/build/**/bin/*
dist/windows-amd64/**
generate-rocm: generate-rocm:
needs: [changes] needs: [changes]
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }} if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
@ -134,7 +148,9 @@ jobs:
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: rocm-${{ matrix.rocm-version }}-libraries name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/build/**/bin/* path: |
llm/build/**/bin/*
dist/windows-amd64/**
# ROCm generation step # ROCm generation step
generate-windows-rocm: generate-windows-rocm:
@ -253,14 +269,9 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }} if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
shell: bash
- uses: golangci/golangci-lint-action@v4 - uses: golangci/golangci-lint-action@v4
with: with:
args: --timeout 8m0s args: --timeout 8m0s -v
test: test:
strategy: strategy:
matrix: matrix:
@ -284,7 +295,6 @@ jobs:
with: with:
go-version-file: go.mod go-version-file: go.mod
cache: true cache: true
- run: go get
- run: | - run: |
case ${{ matrix.arch }} in case ${{ matrix.arch }} in
amd64) echo ARCH=x86_64 ;; amd64) echo ARCH=x86_64 ;;
@ -299,10 +309,6 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }} if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
shell: bash shell: bash
- run: go generate ./... - run: go generate ./...
- run: go build - run: go build

1
.gitignore vendored
View file

@ -12,3 +12,4 @@ ggml-metal.metal
test_data test_data
*.crt *.crt
llm/build llm/build
__debug_bin*

View file

@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64 FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION ARG CMAKE_VERSION
@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64 FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
ARG CMAKE_VERSION ARG CMAKE_VERSION
@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS ARG CGO_CFLAGS
ARG AMDGPU_TARGETS ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN mkdir /tmp/scratch && \ RUN mkdir /tmp/scratch && \
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \ for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
cp ${dep} /tmp/scratch/ || exit 1 ; \ cp ${dep} /tmp/scratch/ || exit 1 ; \
@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64 FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64 FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64 FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
ARG CMAKE_VERSION ARG CMAKE_VERSION
@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64 FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64 FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
# Intermediate stage used for ./scripts/build_linux.sh # Intermediate stage used for ./scripts/build_linux.sh

View file

@ -1,5 +1,5 @@
<div align="center"> <div align="center">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">  <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</div> </div>
# Ollama # Ollama
@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
## Quickstart ## Quickstart
To run and chat with [Llama 2](https://ollama.com/library/llama2): To run and chat with [Llama 3](https://ollama.com/library/llama3):
``` ```
ollama run llama2 ollama run llama3
``` ```
## Model library ## Model library
@ -49,17 +49,14 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download | | Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | ------------------------------ | | ------------------ | ---------- | ----- | ------------------------------ |
| Llama 2 | 7B | 3.8GB | `ollama run llama2` | | Llama 3 | 8B | 4.7GB | `ollama run llama3` |
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
| Mistral | 7B | 4.1GB | `ollama run mistral` | | Mistral | 7B | 4.1GB | `ollama run mistral` |
| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` |
| Phi-2 | 2.7B | 1.7GB | `ollama run phi` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` | | Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
| Starling | 7B | 4.1GB | `ollama run starling-lm` | | Starling | 7B | 4.1GB | `ollama run starling-lm` |
| Code Llama | 7B | 3.8GB | `ollama run codellama` | | Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| Llama 2 13B | 13B | 7.3GB | `ollama run llama2:13b` |
| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` |
| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` |
| LLaVA | 7B | 4.5GB | `ollama run llava` | | LLaVA | 7B | 4.5GB | `ollama run llava` |
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` | | Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` | | Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
@ -97,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information.
### Customize a prompt ### Customize a prompt
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model: Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
``` ```
ollama pull llama2 ollama pull llama3
``` ```
Create a `Modelfile`: Create a `Modelfile`:
``` ```
FROM llama2 FROM llama3
# set the temperature to 1 [higher is more creative, lower is more coherent] # set the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1 PARAMETER temperature 1
@ -141,7 +138,7 @@ ollama create mymodel -f ./Modelfile
### Pull a model ### Pull a model
``` ```
ollama pull llama2 ollama pull llama3
``` ```
> This command can also be used to update a local model. Only the diff will be pulled. > This command can also be used to update a local model. Only the diff will be pulled.
@ -149,13 +146,13 @@ ollama pull llama2
### Remove a model ### Remove a model
``` ```
ollama rm llama2 ollama rm llama3
``` ```
### Copy a model ### Copy a model
``` ```
ollama cp llama2 my-llama2 ollama cp llama3 my-model
``` ```
### Multiline input ### Multiline input
@ -176,10 +173,10 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
The image features a yellow smiley face, which is likely the central focus of the picture. The image features a yellow smiley face, which is likely the central focus of the picture.
``` ```
### Pass in prompt as arguments ### Pass the prompt as an argument
``` ```
$ ollama run llama2 "Summarize this file: $(cat README.md)" $ ollama run llama3 "Summarize this file: $(cat README.md)"
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications. Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
``` ```
@ -226,7 +223,7 @@ Next, start the server:
Finally, in a separate shell, run a model: Finally, in a separate shell, run a model:
``` ```
./ollama run llama2 ./ollama run llama3
``` ```
## REST API ## REST API
@ -237,7 +234,7 @@ Ollama has a REST API for running and managing models.
``` ```
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama3",
"prompt":"Why is the sky blue?" "prompt":"Why is the sky blue?"
}' }'
``` ```
@ -246,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{
``` ```
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "mistral", "model": "llama3",
"messages": [ "messages": [
{ "role": "user", "content": "why is the sky blue?" } { "role": "user", "content": "why is the sky blue?" }
] ]
@ -259,16 +256,18 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop ### Web & Desktop
- [Open WebUI](https://github.com/open-webui/open-webui)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui) - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
- [LibreChat](https://github.com/danny-avila/LibreChat) - [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle) - [Saddle](https://github.com/jikkuatwork/saddle)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file) - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui) - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac) - [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md) - [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core) - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
@ -286,13 +285,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaGUI](https://github.com/enoch1118/ollamaGUI) - [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
- [OpenAOE](https://github.com/InternLM/OpenAOE) - [OpenAOE](https://github.com/InternLM/OpenAOE)
- [Odin Runes](https://github.com/leonid20000/OdinRunes) - [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x) - [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm) - [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat) - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats) - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama) - [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository)
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) - [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow) - [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold)
- [chat](https://github.com/swuecho/chat) (chat web app for teams)
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
### Terminal ### Terminal
@ -308,11 +314,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Oatmeal](https://github.com/dustinblackman/oatmeal) - [Oatmeal](https://github.com/dustinblackman/oatmeal)
- [cmdh](https://github.com/pgibler/cmdh) - [cmdh](https://github.com/pgibler/cmdh)
- [ooo](https://github.com/npahlfer/ooo) - [ooo](https://github.com/npahlfer/ooo)
- [shell-pilot](https://github.com/reid41/shell-pilot)
- [tenere](https://github.com/pythops/tenere) - [tenere](https://github.com/pythops/tenere)
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/). - [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
- [typechat-cli](https://github.com/anaisbetts/typechat-cli) - [typechat-cli](https://github.com/anaisbetts/typechat-cli)
- [ShellOracle](https://github.com/djcopley/ShellOracle) - [ShellOracle](https://github.com/djcopley/ShellOracle)
- [tlm](https://github.com/yusufcanb/tlm) - [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
### Database ### Database
@ -344,9 +352,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md) - [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
- [Elixir LangChain](https://github.com/brainlid/langchain) - [Elixir LangChain](https://github.com/brainlid/langchain)
- [Ollama for R - rollama](https://github.com/JBGruber/rollama) - [Ollama for R - rollama](https://github.com/JBGruber/rollama)
- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex) - [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama) - [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
- [Testcontainers](https://testcontainers.com/modules/ollama/) - [Testcontainers](https://testcontainers.com/modules/ollama/)
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
### Mobile ### Mobile
@ -366,17 +376,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram) - [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation) - [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama) - [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot) - [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
- [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support) - [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot) - [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt) - [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama) - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama) - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace) - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension) - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend) - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support) - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
### Supported backends ### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

View file

@ -1,9 +1,16 @@
// Package api implements the client-side API for code wishing to interact // Package api implements the client-side API for code wishing to interact
// with the ollama service. The methods of the [Client] type correspond to // with the ollama service. The methods of the [Client] type correspond to
// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md // the ollama REST API as described in [the API documentation].
//
// The ollama command-line client itself uses this package to interact with // The ollama command-line client itself uses this package to interact with
// the backend service. // the backend service.
//
// # Examples
//
// Several examples of using this package are available [in the GitHub
// repository].
//
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
package api package api
import ( import (
@ -18,6 +25,7 @@ import (
"net/url" "net/url"
"os" "os"
"runtime" "runtime"
"strconv"
"strings" "strings"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
@ -57,12 +65,36 @@ func checkError(resp *http.Response, body []byte) error {
// If the variable is not specified, a default ollama host and port will be // If the variable is not specified, a default ollama host and port will be
// used. // used.
func ClientFromEnvironment() (*Client, error) { func ClientFromEnvironment() (*Client, error) {
ollamaHost, err := GetOllamaHost()
if err != nil {
return nil, err
}
return &Client{
base: &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
},
http: http.DefaultClient,
}, nil
}
type OllamaHost struct {
Scheme string
Host string
Port string
}
func GetOllamaHost() (OllamaHost, error) {
defaultPort := "11434" defaultPort := "11434"
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") hostVar := os.Getenv("OLLAMA_HOST")
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
scheme, hostport, ok := strings.Cut(hostVar, "://")
switch { switch {
case !ok: case !ok:
scheme, hostport = "http", os.Getenv("OLLAMA_HOST") scheme, hostport = "http", hostVar
case scheme == "http": case scheme == "http":
defaultPort = "80" defaultPort = "80"
case scheme == "https": case scheme == "https":
@ -82,15 +114,24 @@ func ClientFromEnvironment() (*Client, error) {
} }
} }
return &Client{ if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
base: &url.URL{ return OllamaHost{}, ErrInvalidHostPort
}
return OllamaHost{
Scheme: scheme, Scheme: scheme,
Host: net.JoinHostPort(host, port), Host: host,
}, Port: port,
http: http.DefaultClient,
}, nil }, nil
} }
func NewClient(base *url.URL, http *http.Client) *Client {
return &Client{
base: base,
http: http,
}
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader var reqBody io.Reader
var data []byte var data []byte
@ -265,8 +306,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
}) })
} }
// PushProgressFunc is a function that [Client.Push] invokes when progress is
// made.
// It's similar to other progress function types like [PullProgressFunc].
type PushProgressFunc func(ProgressResponse) error type PushProgressFunc func(ProgressResponse) error
// Push uploads a model to the model library; requires registering for ollama.ai
// and adding a public key first. fn is called each time progress is made on
// the request and can be used to display a progress bar, etc.
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error { func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error { return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
var resp ProgressResponse var resp ProgressResponse
@ -278,8 +325,15 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
}) })
} }
// CreateProgressFunc is a function that [Client.Create] invokes when progress
// is made.
// It's similar to other progress function types like [PullProgressFunc].
type CreateProgressFunc func(ProgressResponse) error type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]).
//
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse var resp ProgressResponse
@ -291,6 +345,7 @@ func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgre
}) })
} }
// List lists models that are available locally.
func (c *Client) List(ctx context.Context) (*ListResponse, error) { func (c *Client) List(ctx context.Context) (*ListResponse, error) {
var lr ListResponse var lr ListResponse
if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil { if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
@ -299,6 +354,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
return &lr, nil return &lr, nil
} }
// Copy copies a model - creating a model with another name from an existing
// model.
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error { func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil { if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
return err return err
@ -306,6 +363,7 @@ func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
return nil return nil
} }
// Delete deletes a model and its data.
func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error { func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil { if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
return err return err
@ -313,6 +371,7 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
return nil return nil
} }
// Show obtains model information, including details, modelfile, license etc.
func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) { func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
var resp ShowResponse var resp ShowResponse
if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil { if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
@ -321,12 +380,16 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
return &resp, nil return &resp, nil
} }
// Hearbeat checks if the server has started and is responsive; if yes, it
// returns nil, otherwise an error.
func (c *Client) Heartbeat(ctx context.Context) error { func (c *Client) Heartbeat(ctx context.Context) error {
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil { if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
return err return err
} }
return nil return nil
} }
// Embeddings generates embeddings from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil { if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
@ -335,10 +398,13 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
return &resp, nil return &resp, nil
} }
// CreateBlob creates a blob from a file on the server. digest is the
// expected SHA256 digest of the file, and r represents the file.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error { func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil) return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
} }
// Version returns the Ollama server version as a string.
func (c *Client) Version(ctx context.Context) (string, error) { func (c *Client) Version(ctx context.Context) (string, error) {
var version struct { var version struct {
Version string `json:"version"` Version string `json:"version"`

View file

@ -1,6 +1,12 @@
package api package api
import "testing" import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestClientFromEnvironment(t *testing.T) { func TestClientFromEnvironment(t *testing.T) {
type testCase struct { type testCase struct {
@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
} }
}) })
} }
hostTestCases := map[string]*testCase{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
"zero port": {value: ":0", expect: ":0"},
"too large port": {value: ":66000", err: ErrInvalidHostPort},
"too small port": {value: ":-1", err: ErrInvalidHostPort},
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
}
for k, v := range hostTestCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
oh, err := GetOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if err == nil {
host := net.JoinHostPort(oh.Host, oh.Port)
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
}
})
}
} }

View file

@ -2,6 +2,7 @@ package api
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"math" "math"
"os" "os"
@ -11,6 +12,7 @@ import (
"time" "time"
) )
// StatusError is an error with and HTTP status code.
type StatusError struct { type StatusError struct {
StatusCode int StatusCode int
Status string Status string
@ -31,6 +33,7 @@ func (e StatusError) Error() string {
} }
} }
// ImageData represents the raw binary data of an image file.
type ImageData []byte type ImageData []byte
// GenerateRequest describes a request sent by [Client.Generate]. While you // GenerateRequest describes a request sent by [Client.Generate]. While you
@ -76,22 +79,39 @@ type GenerateRequest struct {
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
// ChatRequest describes a request sent by [Client.Chat].
type ChatRequest struct { type ChatRequest struct {
// Model is the model name, as in [GenerateRequest].
Model string `json:"model"` Model string `json:"model"`
// Messages is the messages of the chat - can be used to keep a chat memory.
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
// Stream enable streaming of returned response; true by default.
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
// Format is the format to return the response in (e.g. "json").
Format string `json:"format"` Format string `json:"format"`
// KeepAlive controls how long the model will stay loaded into memory
// followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
// Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct { type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"] Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
} }
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct { type ChatResponse struct {
Model string `json:"model"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
@ -111,7 +131,8 @@ type Metrics struct {
EvalDuration time.Duration `json:"eval_duration,omitempty"` EvalDuration time.Duration `json:"eval_duration,omitempty"`
} }
// Options specified in GenerateRequest, if you add a new option here add it to the API docs also // Options specified in [GenerateRequest], if you add a new option here add it
// to the API docs also.
type Options struct { type Options struct {
Runner Runner
@ -157,18 +178,28 @@ type Runner struct {
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"` RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
} }
// EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct { type EmbeddingRequest struct {
// Model is the model name.
Model string `json:"model"` Model string `json:"model"`
// Prompt is the textual prompt to embed.
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
// EmbeddingResponse is the response from [Client.Embeddings].
type EmbeddingResponse struct { type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct { type CreateRequest struct {
Model string `json:"model"` Model string `json:"model"`
Path string `json:"path"` Path string `json:"path"`
@ -180,6 +211,7 @@ type CreateRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }
// DeleteRequest is the request passed to [Client.Delete].
type DeleteRequest struct { type DeleteRequest struct {
Model string `json:"model"` Model string `json:"model"`
@ -187,6 +219,7 @@ type DeleteRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }
// ShowRequest is the request passed to [Client.Show].
type ShowRequest struct { type ShowRequest struct {
Model string `json:"model"` Model string `json:"model"`
System string `json:"system"` System string `json:"system"`
@ -198,6 +231,7 @@ type ShowRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }
// ShowResponse is the response returned from [Client.Show].
type ShowResponse struct { type ShowResponse struct {
License string `json:"license,omitempty"` License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"` Modelfile string `json:"modelfile,omitempty"`
@ -208,11 +242,13 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
} }
// CopyRequest is the request passed to [Client.Copy].
type CopyRequest struct { type CopyRequest struct {
Source string `json:"source"` Source string `json:"source"`
Destination string `json:"destination"` Destination string `json:"destination"`
} }
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct { type PullRequest struct {
Model string `json:"model"` Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` Insecure bool `json:"insecure,omitempty"`
@ -224,6 +260,8 @@ type PullRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }
// ProgressResponse is the response passed to progress functions like
// [PullProgressFunc] and [PushProgressFunc].
type ProgressResponse struct { type ProgressResponse struct {
Status string `json:"status"` Status string `json:"status"`
Digest string `json:"digest,omitempty"` Digest string `json:"digest,omitempty"`
@ -231,6 +269,7 @@ type ProgressResponse struct {
Completed int64 `json:"completed,omitempty"` Completed int64 `json:"completed,omitempty"`
} }
// PushRequest is the request passed to [Client.Push].
type PushRequest struct { type PushRequest struct {
Model string `json:"model"` Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` Insecure bool `json:"insecure,omitempty"`
@ -242,10 +281,12 @@ type PushRequest struct {
Name string `json:"name"` Name string `json:"name"`
} }
// ListResponse is the response from [Client.List].
type ListResponse struct { type ListResponse struct {
Models []ModelResponse `json:"models"` Models []ModelResponse `json:"models"`
} }
// ModelResponse is a single model description in [ListResponse].
type ModelResponse struct { type ModelResponse struct {
Name string `json:"name"` Name string `json:"name"`
Model string `json:"model"` Model string `json:"model"`
@ -259,17 +300,28 @@ type TokenResponse struct {
Token string `json:"token"` Token string `json:"token"`
} }
// GenerateResponse is the response passed into [GenerateResponseFunc].
type GenerateResponse struct { type GenerateResponse struct {
// Model is the model name that generated the response.
Model string `json:"model"` Model string `json:"model"`
//CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
// Response is the textual response itself.
Response string `json:"response"` Response string `json:"response"`
// Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`
// Context is an encoding of the conversation used in this response; this
// can be sent in the next request to keep a conversational memory.
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
Metrics Metrics
} }
// ModelDetails provides details about a model.
type ModelDetails struct { type ModelDetails struct {
ParentModel string `json:"parent_model"` ParentModel string `json:"parent_model"`
Format string `json:"format"` Format string `json:"format"`
@ -307,7 +359,9 @@ func (m *Metrics) Summary() {
} }
} }
var ErrInvalidOpts = fmt.Errorf("invalid options") // ErrInvalidOpts is returned when invalid options are passed to the client.
var ErrInvalidOpts = errors.New("invalid options")
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
func (opts *Options) FromMap(m map[string]interface{}) error { func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
@ -392,11 +446,15 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
return nil return nil
} }
// DefaultOptions is the default set of options for [GenerateRequest]; these
// values are used unless the user specifies other values explicitly.
func DefaultOptions() Options { func DefaultOptions() Options {
return Options{ return Options{
// options set on request to runner // options set on request to runner
NumPredict: -1, NumPredict: -1,
NumKeep: 0,
// set a minimal num_keep to avoid issues on context shifts
NumKeep: 4,
Temperature: 0.8, Temperature: 0.8,
TopK: 40, TopK: 40,
TopP: 0.9, TopP: 0.9,
@ -432,6 +490,13 @@ type Duration struct {
time.Duration time.Duration
} }
func (d Duration) MarshalJSON() ([]byte, error) {
if d.Duration < 0 {
return []byte("-1"), nil
}
return []byte("\"" + d.Duration.String() + "\""), nil
}
func (d *Duration) UnmarshalJSON(b []byte) (err error) { func (d *Duration) UnmarshalJSON(b []byte) (err error) {
var v any var v any
if err := json.Unmarshal(b, &v); err != nil { if err := json.Unmarshal(b, &v); err != nil {
@ -445,7 +510,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if t < 0 { if t < 0 {
d.Duration = time.Duration(math.MaxInt64) d.Duration = time.Duration(math.MaxInt64)
} else { } else {
d.Duration = time.Duration(t * float64(time.Second)) d.Duration = time.Duration(int(t) * int(time.Second))
} }
case string: case string:
d.Duration, err = time.ParseDuration(t) d.Duration, err = time.ParseDuration(t)
@ -455,6 +520,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if d.Duration < 0 { if d.Duration < 0 {
d.Duration = time.Duration(math.MaxInt64) d.Duration = time.Duration(math.MaxInt64)
} }
default:
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
} }
return nil return nil

View file

@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req: `{ "keep_alive": 42 }`, req: `{ "keep_alive": 42 }`,
exp: &Duration{42 * time.Second}, exp: &Duration{42 * time.Second},
}, },
{
name: "Positive Float",
req: `{ "keep_alive": 42.5 }`,
exp: &Duration{42 * time.Second},
},
{ {
name: "Positive Integer String", name: "Positive Integer String",
req: `{ "keep_alive": "42m" }`, req: `{ "keep_alive": "42m" }`,
@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req: `{ "keep_alive": -1 }`, req: `{ "keep_alive": -1 }`,
exp: &Duration{math.MaxInt64}, exp: &Duration{math.MaxInt64},
}, },
{
name: "Negative Float",
req: `{ "keep_alive": -3.14 }`,
exp: &Duration{math.MaxInt64},
},
{ {
name: "Negative Integer String", name: "Negative Integer String",
req: `{ "keep_alive": "-1m" }`, req: `{ "keep_alive": "-1m" }`,
@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
}) })
} }
} }
func TestDurationMarshalUnmarshal(t *testing.T) {
tests := []struct {
name string
input time.Duration
expected time.Duration
}{
{
"negative duration",
time.Duration(-1),
time.Duration(math.MaxInt64),
},
{
"positive duration",
time.Duration(42 * time.Second),
time.Duration(42 * time.Second),
},
{
"another positive duration",
time.Duration(42 * time.Minute),
time.Duration(42 * time.Minute),
},
{
"zero duration",
time.Duration(0),
time.Duration(0),
},
{
"max duration",
time.Duration(math.MaxInt64),
time.Duration(math.MaxInt64),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b, err := json.Marshal(Duration{test.input})
require.NoError(t, err)
var d Duration
err = json.Unmarshal(b, &d)
require.NoError(t, err)
assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
})
}
}

View file

@ -5,12 +5,14 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"github.com/ollama/ollama/server/envconfig"
) )
func InitLogging() { func InitLogging() {
level := slog.LevelInfo level := slog.LevelInfo
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { if envconfig.Debug {
level = slog.LevelDebug level = slog.LevelDebug
} }

View file

@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
return command return command
} }
func SpawnServer(ctx context.Context, command string) (chan int, error) { func start(ctx context.Context, command string) (*exec.Cmd, error) {
done := make(chan int)
logDir := filepath.Dir(ServerLogFile)
_, err := os.Stat(logDir)
if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
cmd := getCmd(ctx, getCLIFullPath(command)) cmd := getCmd(ctx, getCLIFullPath(command))
// send stdout and stderr to a file
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err) return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
} }
stderr, err := cmd.StderrPipe() stderr, err := cmd.StderrPipe()
if err != nil { if err != nil {
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err) return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
}
stdin, err := cmd.StdinPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
} }
// TODO - rotation // TODO - rotation
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755) logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
if err != nil { if err != nil {
return done, fmt.Errorf("failed to create server log %w", err) return nil, fmt.Errorf("failed to create server log: %w", err)
} }
logDir := filepath.Dir(ServerLogFile)
_, err = os.Stat(logDir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
}
if err := os.MkdirAll(logDir, 0o755); err != nil {
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
go func() { go func() {
defer logFile.Close() defer logFile.Close()
io.Copy(logFile, stdout) //nolint:errcheck io.Copy(logFile, stdout) //nolint:errcheck
@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
// run the command and wait for it to finish // run the command and wait for it to finish
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
return done, fmt.Errorf("failed to start server %w", err) return nil, fmt.Errorf("failed to start server %w", err)
} }
if cmd.Process != nil { if cmd.Process != nil {
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid)) slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
} }
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile)) slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
return cmd, nil
}
func SpawnServer(ctx context.Context, command string) (chan int, error) {
done := make(chan int)
go func() { go func() {
// Keep the server running unless we're shuttind down the app // Keep the server running unless we're shuttind down the app
crashCount := 0 crashCount := 0
for { for {
slog.Info("starting server...")
cmd, err := start(ctx, command)
if err != nil {
crashCount++
slog.Error(fmt.Sprintf("failed to start server %s", err))
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
continue
}
cmd.Wait() //nolint:errcheck cmd.Wait() //nolint:errcheck
stdin.Close()
var code int var code int
if cmd.ProcessState != nil { if cmd.ProcessState != nil {
code = cmd.ProcessState.ExitCode() code = cmd.ProcessState.ExitCode()
@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
default: default:
crashCount++ crashCount++
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code)) slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
if err := cmd.Start(); err != nil { break
slog.Error(fmt.Sprintf("failed to restart server %s", err))
// Keep trying, but back off if we keep failing
time.Sleep(time.Duration(crashCount) * time.Second)
}
} }
} }
}() }()
return done, nil return done, nil
} }

View file

@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd "/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed "/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
} }
// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts) // make the upgrade as quiet as possible (no GUI, no prompts)
// TODO - temporarily disable since we're pinning in debug mode for the preview
// if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" {
installArgs = append(installArgs, installArgs = append(installArgs,
"/SP", // Skip the "This will install... Do you wish to continue" prompt "/SP", // Skip the "This will install... Do you wish to continue" prompt
"/SUPPRESSMSGBOXES", "/SUPPRESSMSGBOXES",
"/SILENT", "/SILENT",
"/VERYSILENT", "/VERYSILENT",
) )
// }
// Safeguard in case we have requests in flight that need to drain... // Safeguard in case we have requests in flight that need to drain...
slog.Info("Waiting for server to shutdown") slog.Info("Waiting for server to shutdown")

View file

@ -88,15 +88,12 @@ DialogFontSize=12
[Files] [Files]
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
; Assumes v5.7, may need adjustments for v6 #if DirExists("..\dist\windows-amd64\rocm")
#if GetEnv("HIP_PATH") != "" Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
; amdhip64.dll dependency comes from the driver and must be installed already
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion
#endif #endif
@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
;FinishedHeadingLabel=Run your first model ;FinishedHeadingLabel=Run your first model
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama2 ;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3
;ClickFinish=%n ;ClickFinish=%n
[Registry] [Registry]

View file

@ -10,12 +10,44 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
const defaultPrivateKey = "id_ed25519" const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil {
return "", err
}
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
return strings.TrimSpace(string(publicKey)), nil
}
func NewNonce(r io.Reader, length int) (string, error) { func NewNonce(r io.Reader, length int) (string, error) {
nonce := make([]byte, length) nonce := make([]byte, length)
if _, err := io.ReadFull(r, nonce); err != nil { if _, err := io.ReadFull(r, nonce); err != nil {
@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
} }
func Sign(ctx context.Context, bts []byte) (string, error) { func Sign(ctx context.Context, bts []byte) (string, error) {
home, err := os.UserHomeDir() keyPath, err := keyPath()
if err != nil { if err != nil {
return "", err return "", err
} }
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath) privateKeyFile, err := os.ReadFile(keyPath)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View file

@ -17,6 +17,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
@ -31,10 +32,12 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@ -53,14 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.Stop() defer p.Stop()
bars := make(map[string]*progress.Bar) f, err := os.Open(filename)
modelfile, err := os.ReadFile(filename)
if err != nil { if err != nil {
return err return err
} }
defer f.Close()
commands, err := parser.Parse(bytes.NewReader(modelfile)) modelfile, err := model.ParseFile(f)
if err != nil { if err != nil {
return err return err
} }
@ -74,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
for _, c := range commands { for i := range modelfile.Commands {
switch c.Name { switch modelfile.Commands[i].Name {
case "model", "adapter": case "model", "adapter":
path := c.Args path := modelfile.Commands[i].Args
if path == "~" { if path == "~" {
path = home path = home
} else if strings.HasPrefix(path, "~/") { } else if strings.HasPrefix(path, "~/") {
@ -89,101 +91,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
fi, err := os.Stat(path) fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" { if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue continue
} else if err != nil { } else if err != nil {
return err return err
} }
// TODO make this work w/ adapters
if fi.IsDir() { if fi.IsDir() {
tf, err := os.CreateTemp("", "ollama-tf") // this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
if err != nil { if err != nil {
return err return err
} }
defer os.RemoveAll(tf.Name()) defer os.RemoveAll(tempfile)
zf := zip.NewWriter(tf) path = tempfile
files := []string{}
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
if err != nil {
return err
} else if len(tfiles) == 0 {
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return err
}
}
files = append(files, tfiles...)
if len(files) == 0 {
return fmt.Errorf("no models were found in '%s'", path)
}
// add the safetensor/torch config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "params.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))
for _, fn := range files {
f, err := os.Open(fn)
// just skip whatever files aren't there
if os.IsNotExist(err) {
if strings.HasSuffix(fn, "tokenizer.model") {
// try the parent dir before giving up
parentDir := filepath.Dir(path)
newFn := filepath.Join(parentDir, "tokenizer.model")
f, err = os.Open(newFn)
if os.IsNotExist(err) {
continue
} else if err != nil {
return err
}
} else {
continue
}
} else if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
h, err := zip.FileInfoHeader(fi)
if err != nil {
return err
}
h.Name = filepath.Base(fn)
h.Method = zip.Store
w, err := zf.CreateHeader(h)
if err != nil {
return err
}
_, err = io.Copy(w, f)
if err != nil {
return err
}
}
if err := zf.Close(); err != nil {
return err
}
if err := tf.Close(); err != nil {
return err
}
path = tf.Name()
} }
digest, err := createBlob(cmd, client, path) digest, err := createBlob(cmd, client, path)
@ -191,10 +114,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest)) modelfile.Commands[i].Args = "@" + digest
} }
} }
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
spinner.Stop() spinner.Stop()
@ -220,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
quantization, _ := cmd.Flags().GetString("quantization") quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil { if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err return err
} }
@ -228,6 +152,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return "", err
}
defer tempfile.Close()
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
}
}
return matches, nil
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else {
return "", errors.New("no safetensors or torch files found")
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
if err != nil {
return "", err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
for _, file := range files {
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return "", err
}
zfi, err := zip.FileInfoHeader(fi)
if err != nil {
return "", err
}
zf, err := zipfile.CreateHeader(zfi)
if err != nil {
return "", err
}
if _, err := io.Copy(zf, f); err != nil {
return "", err
}
}
return tempfile.Name(), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
@ -322,6 +354,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts) return generateInteractive(cmd, opts)
} }
func errFromUnknownKey(unknownKeyErr error) error {
// find SSH public key in the error message
sshKeyPattern := `ssh-\w+ [^\s"]+`
re := regexp.MustCompile(sshKeyPattern)
matches := re.FindStringSubmatch(unknownKeyErr.Error())
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
if err != nil {
return unknownKeyErr
}
localPubKey = strings.TrimSpace(string(svcPubKey))
}
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
if serverPubKey != localPubKey {
return unknownKeyErr
}
var msg strings.Builder
msg.WriteString(unknownKeyErr.Error())
msg.WriteString("\n\nYour ollama key is:\n")
msg.WriteString(localPubKey)
msg.WriteString("\nAdd your key at:\n")
msg.WriteString("https://ollama.com/settings/keys")
return errors.New(msg.String())
}
return unknownKeyErr
}
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@ -369,6 +442,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure} request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(cmd.Context(), &request, fn); err != nil { if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
}
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
host := model.ParseName(args[0]).Host
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// re-throw an error with a more user-friendly message
return errFromUnknownKey(err)
}
return err return err
} }
@ -796,24 +883,27 @@ func generate(cmd *cobra.Command, opts runOptions) error {
} }
func RunServer(cmd *cobra.Command, _ []string) error { func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'")) // retrieve the OLLAMA_HOST environment variable
ollamaHost, err := api.GetOllamaHost()
if err != nil { if err != nil {
host, port = "127.0.0.1", "11434" return err
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
host = ip.String()
}
} }
if err := initializeKeypair(); err != nil { if err := initializeKeypair(); err != nil {
return err return err
} }
ln, err := net.Listen("tcp", net.JoinHostPort(host, port)) ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
if err != nil { if err != nil {
return err return err
} }
return server.Serve(ln) err = server.Serve(ln)
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
} }
func initializeKeypair() error { func initializeKeypair() error {
@ -1034,7 +1124,7 @@ Environment Variables:
RunE: ListHandler, RunE: ListHandler,
} }
copyCmd := &cobra.Command{ copyCmd := &cobra.Command{
Use: "cp SOURCE TARGET", Use: "cp SOURCE DESTINATION",
Short: "Copy a model", Short: "Copy a model",
Args: cobra.ExactArgs(2), Args: cobra.ExactArgs(2),
PreRunE: checkServerHeartbeat, PreRunE: checkServerHeartbeat,

View file

@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /show Show model information") fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model") fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session") fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@ -161,7 +162,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions") fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions") fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU") fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters") fmt.Fprintln(os.Stderr, " /set parameter stop <string> <string> ... Set the stop parameters")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
fmt.Printf("Created new model '%s'\n", args[1]) fmt.Printf("Created new model '%s'\n", args[1])
continue continue
case strings.HasPrefix(line, "/clear"):
opts.Messages = []api.Message{}
fmt.Println("Cleared session context")
continue
case strings.HasPrefix(line, "/set"): case strings.HasPrefix(line, "/set"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
@ -31,6 +32,10 @@ type Params struct {
EoSTokenID int `json:"eos_token_id"` EoSTokenID int `json:"eos_token_id"`
HeadDimension int `json:"head_dim"` HeadDimension int `json:"head_dim"`
PaddingTokenID int `json:"pad_token_id"` PaddingTokenID int `json:"pad_token_id"`
RopeFrequencyBase float64 `json:"rope_theta"`
Experts int `json:"num_local_experts"`
ExpertsUsed int `json:"num_experts_per_tok"`
ByteOrder ByteOrder
} }
@ -43,7 +48,7 @@ type ByteOrder interface {
type ModelArch interface { type ModelArch interface {
GetTensors() error GetTensors() error
LoadVocab() error LoadVocab() error
WriteGGUF() (string, error) WriteGGUF(io.WriteSeeker) error
} }
type ModelFormat interface { type ModelFormat interface {

View file

@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
return nil return nil
} }
func (m *GemmaModel) WriteGGUF() (string, error) { func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "gemma", "general.architecture": "gemma",
"general.name": m.Name, "general.name": m.Name,
@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
"tokenizer.ggml.add_eos_token": false, "tokenizer.ggml.add_eos_token": false,
} }
f, err := os.CreateTemp("", "ollama-gguf") return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
} }

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"os"
"regexp" "regexp"
"strings" "strings"
@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
return nil return nil
} }
func (m *LlamaModel) WriteGGUF() (string, error) { func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "llama", "general.architecture": "llama",
"general.name": m.Name, "general.name": m.Name,
@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
"tokenizer.ggml.add_eos_token": false, "tokenizer.ggml.add_eos_token": false,
} }
f, err := os.CreateTemp("", "ollama-gguf") return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
return f.Name(), nil
} }

View file

@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
return nil return nil
} }
func (m *MistralModel) WriteGGUF() (string, error) { func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{ kv := llm.KV{
"general.architecture": "llama", "general.architecture": "llama",
"general.name": m.Name, "general.name": m.Name,
@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
"tokenizer.ggml.unknown_token_id": uint32(0), "tokenizer.ggml.unknown_token_id": uint32(0),
} }
f, err := os.CreateTemp("", "ollama-gguf") return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
} }

85
convert/mixtral.go Normal file
View file

@ -0,0 +1,85 @@
package convert
import (
"io"
"regexp"
"github.com/ollama/ollama/llm"
)
type MixtralModel struct {
ModelData
}
func (m *MixtralModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
return err
}
m.Tensors = []llm.Tensor{}
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return err
}
for _, l := range t {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
func (m *MixtralModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params)
if err != nil {
return err
}
m.Vocab = v
return nil
}
func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
"llama.block_count": uint32(m.Params.HiddenLayers),
"llama.context_length": uint32(m.Params.ContextSize),
"llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"llama.expert_count": uint32(m.Params.Experts),
"llama.expert_used_count": uint32(m.Params.ExpertsUsed),
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
}
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}

View file

@ -53,7 +53,7 @@ func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Ten
var err error var err error
t, offset, err = m.readTensors(f, offset, params) t, offset, err = m.readTensors(f, offset, params)
if err != nil { if err != nil {
slog.Error("%v", err) slog.Error(err.Error())
return nil, err return nil, err
} }
tensors = append(tensors, t...) tensors = append(tensors, t...)
@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
} }
slices.Sort(keys) slices.Sort(keys)
slog.Info("converting layers") slog.Info("converting layers")
var tensors []llm.Tensor var tensors []llm.Tensor
@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
return nil, 0, err return nil, 0, err
} }
slog.Debug(fmt.Sprintf("metadata = %#v", data))
var size uint64 var size uint64
var kind uint32 var kind uint32
switch len(data.Shape) { switch len(data.Shape) {
@ -124,7 +122,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
ggufName, err := m.GetLayerName(k) ggufName, err := m.GetLayerName(k)
if err != nil { if err != nil {
slog.Error("%v", err) slog.Error(err.Error())
return nil, 0, err return nil, 0, err
} }
@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
padding: 8 + jsonSize, padding: 8 + jsonSize,
} }
tensors = append(tensors, t)
offset += size offset += size
tensors = append(tensors, t)
} }
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors))) slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
slog.Debug(fmt.Sprintf("offset = %d", offset)) slog.Debug(fmt.Sprintf("offset = %d", offset))
return tensors, offset, nil return tensors, offset, nil
} }
@ -194,6 +194,10 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight", "model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight", "model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight", "model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
"model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
} }
v, ok := directMap[n] v, ok := directMap[n]
@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
Format: m, Format: m,
}, },
}, nil }, nil
case "MixtralForCausalLM":
return &MixtralModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
case "GemmaForCausalLM": case "GemmaForCausalLM":
return &GemmaModel{ return &GemmaModel{
ModelData{ ModelData{

View file

@ -74,7 +74,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
ggufName, err := tf.GetLayerName(k.(string)) ggufName, err := tf.GetLayerName(k.(string))
if err != nil { if err != nil {
slog.Error("%v", err) slog.Error(err.Error())
return nil, err return nil, err
} }
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName)) slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))

View file

@ -17,7 +17,7 @@
### Model names ### Model names
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version. Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
### Durations ### Durations
@ -66,7 +66,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama3",
"prompt": "Why is the sky blue?" "prompt": "Why is the sky blue?"
}' }'
``` ```
@ -77,7 +77,7 @@ A stream of JSON objects is returned:
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T08:52:19.385406455-07:00", "created_at": "2023-08-04T08:52:19.385406455-07:00",
"response": "The", "response": "The",
"done": false "done": false
@ -90,16 +90,16 @@ The final response in the stream also includes additional data about the generat
- `load_duration`: time spent in nanoseconds loading the model - `load_duration`: time spent in nanoseconds loading the model
- `prompt_eval_count`: number of tokens in the prompt - `prompt_eval_count`: number of tokens in the prompt
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
- `eval_count`: number of tokens the response - `eval_count`: number of tokens in the response
- `eval_duration`: time in nanoseconds spent generating the response - `eval_duration`: time in nanoseconds spent generating the response
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory - `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
- `response`: empty if the response was streamed, if not streamed, this will contain the full response - `response`: empty if the response was streamed, if not streamed, this will contain the full response
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`. To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` * `10^9`.
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"response": "", "response": "",
"done": true, "done": true,
@ -121,7 +121,7 @@ A response can be received in one reply when streaming is off.
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama3",
"prompt": "Why is the sky blue?", "prompt": "Why is the sky blue?",
"stream": false "stream": false
}' }'
@ -133,7 +133,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.", "response": "The sky is blue because it is the color of the sky.",
"done": true, "done": true,
@ -155,7 +155,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama3",
"prompt": "What color is the sky at different times of the day? Respond using JSON", "prompt": "What color is the sky at different times of the day? Respond using JSON",
"format": "json", "format": "json",
"stream": false "stream": false
@ -166,7 +166,7 @@ curl http://localhost:11434/api/generate -d '{
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-11-09T21:07:55.186497Z", "created_at": "2023-11-09T21:07:55.186497Z",
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n", "response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
"done": true, "done": true,
@ -289,7 +289,7 @@ If you want to set custom options for the model at runtime rather than in the Mo
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama3",
"prompt": "Why is the sky blue?", "prompt": "Why is the sky blue?",
"stream": false, "stream": false,
"options": { "options": {
@ -332,7 +332,7 @@ curl http://localhost:11434/api/generate -d '{
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.", "response": "The sky is blue because it is the color of the sky.",
"done": true, "done": true,
@ -354,7 +354,7 @@ If an empty prompt is provided, the model will be loaded into memory.
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2" "model": "llama3"
}' }'
``` ```
@ -364,7 +364,7 @@ A single JSON object is returned:
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-12-18T19:52:07.071755Z", "created_at": "2023-12-18T19:52:07.071755Z",
"response": "", "response": "",
"done": true "done": true
@ -407,7 +407,7 @@ Send a chat message with a streaming response.
```shell ```shell
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "llama2", "model": "llama3",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -423,7 +423,7 @@ A stream of JSON objects is returned:
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T08:52:19.385406455-07:00", "created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": { "message": {
"role": "assistant", "role": "assistant",
@ -438,7 +438,7 @@ Final response:
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"done": true, "done": true,
"total_duration": 4883583458, "total_duration": 4883583458,
@ -456,7 +456,7 @@ Final response:
```shell ```shell
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "llama2", "model": "llama3",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -471,7 +471,7 @@ curl http://localhost:11434/api/chat -d '{
```json ```json
{ {
"model": "registry.ollama.ai/library/llama2:latest", "model": "registry.ollama.ai/library/llama3:latest",
"created_at": "2023-12-12T14:13:43.416799Z", "created_at": "2023-12-12T14:13:43.416799Z",
"message": { "message": {
"role": "assistant", "role": "assistant",
@ -495,7 +495,7 @@ Send a chat message with a conversation history. You can use this same approach
```shell ```shell
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "llama2", "model": "llama3",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -519,7 +519,7 @@ A stream of JSON objects is returned:
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T08:52:19.385406455-07:00", "created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": { "message": {
"role": "assistant", "role": "assistant",
@ -533,7 +533,7 @@ Final response:
```json ```json
{ {
"model": "llama2", "model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"done": true, "done": true,
"total_duration": 8113331500, "total_duration": 8113331500,
@ -591,7 +591,7 @@ curl http://localhost:11434/api/chat -d '{
```shell ```shell
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "llama2", "model": "llama3",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -609,7 +609,7 @@ curl http://localhost:11434/api/chat -d '{
```json ```json
{ {
"model": "registry.ollama.ai/library/llama2:latest", "model": "registry.ollama.ai/library/llama3:latest",
"created_at": "2023-12-12T14:13:43.416799Z", "created_at": "2023-12-12T14:13:43.416799Z",
"message": { "message": {
"role": "assistant", "role": "assistant",
@ -651,7 +651,7 @@ Create a new model from a `Modelfile`.
```shell ```shell
curl http://localhost:11434/api/create -d '{ curl http://localhost:11434/api/create -d '{
"name": "mario", "name": "mario",
"modelfile": "FROM llama2\nSYSTEM You are mario from Super Mario Bros." "modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros."
}' }'
``` ```
@ -758,7 +758,7 @@ A single JSON object will be returned.
} }
}, },
{ {
"name": "llama2:latest", "name": "llama3:latest",
"modified_at": "2023-12-07T09:32:18.757212583-08:00", "modified_at": "2023-12-07T09:32:18.757212583-08:00",
"size": 3825819519, "size": 3825819519,
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e", "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
@ -792,7 +792,7 @@ Show information about a model including details, modelfile, template, parameter
```shell ```shell
curl http://localhost:11434/api/show -d '{ curl http://localhost:11434/api/show -d '{
"name": "llama2" "name": "llama3"
}' }'
``` ```
@ -827,8 +827,8 @@ Copy a model. Creates a model with another name from an existing model.
```shell ```shell
curl http://localhost:11434/api/copy -d '{ curl http://localhost:11434/api/copy -d '{
"source": "llama2", "source": "llama3",
"destination": "llama2-backup" "destination": "llama3-backup"
}' }'
``` ```
@ -854,7 +854,7 @@ Delete a model and its data.
```shell ```shell
curl -X DELETE http://localhost:11434/api/delete -d '{ curl -X DELETE http://localhost:11434/api/delete -d '{
"name": "llama2:13b" "name": "llama3:13b"
}' }'
``` ```
@ -882,7 +882,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
```shell ```shell
curl http://localhost:11434/api/pull -d '{ curl http://localhost:11434/api/pull -d '{
"name": "llama2" "name": "llama3"
}' }'
``` ```

View file

@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70") a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
Then generate dependencies: Then generate dependencies:

View file

@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter:
``` ```
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama3",
"prompt": "Why is the sky blue?", "prompt": "Why is the sky blue?",
"options": { "options": {
"num_ctx": 4096 "num_ctx": 4096
@ -140,7 +140,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
- macOS: `~/.ollama/models` - macOS: `~/.ollama/models`
- Linux: `/usr/share/ollama/.ollama/models` - Linux: `/usr/share/ollama/.ollama/models`
- Windows: `C:\Users\<username>\.ollama\models` - Windows: `C:\Users\%username%\.ollama\models`
### How do I set them to a different location? ### How do I set them to a different location?
@ -221,10 +221,20 @@ The `keep_alive` parameter can be set to:
For example, to preload a model and leave it in memory use: For example, to preload a model and leave it in memory use:
```shell ```shell
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}' curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": -1}'
``` ```
To unload the model and free up memory use: To unload the model and free up memory use:
```shell ```shell
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}' curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
``` ```
Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
## How do I manage the maximum number of requests the server can queue
If too many requests are sent to the server, it will respond with a 503 error
indicating the server is overloaded. You can adjust how many requests may be
queue by setting `OLLAMA_MAX_QUEUE`

View file

@ -125,7 +125,7 @@ Publishing models is in early alpha. If you'd like to publish your model to shar
1. Create [an account](https://ollama.com/signup) 1. Create [an account](https://ollama.com/signup)
2. Copy your Ollama public key: 2. Copy your Ollama public key:
- macOS: `cat ~/.ollama/id_ed25519.pub` - macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
- Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub` - Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
- Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub` - Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys) 3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
@ -136,6 +136,8 @@ Next, copy your model to your username's namespace:
ollama cp example <your username>/example ollama cp example <your username>/example
``` ```
> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
Then push the model: Then push the model:
``` ```

View file

@ -105,7 +105,7 @@ sudo chmod +x /usr/bin/ollama
To view logs of Ollama running as a startup service, run: To view logs of Ollama running as a startup service, run:
```bash ```bash
journalctl -u ollama journalctl -e -u ollama
``` ```
## Uninstall ## Uninstall

View file

@ -10,7 +10,7 @@ A model file is the blueprint to create and share models with Ollama.
- [Examples](#examples) - [Examples](#examples)
- [Instructions](#instructions) - [Instructions](#instructions)
- [FROM (Required)](#from-required) - [FROM (Required)](#from-required)
- [Build from llama2](#build-from-llama2) - [Build from llama3](#build-from-llama3)
- [Build from a bin file](#build-from-a-bin-file) - [Build from a bin file](#build-from-a-bin-file)
- [PARAMETER](#parameter) - [PARAMETER](#parameter)
- [Valid Parameters and Values](#valid-parameters-and-values) - [Valid Parameters and Values](#valid-parameters-and-values)
@ -48,7 +48,7 @@ INSTRUCTION arguments
An example of a `Modelfile` creating a mario blueprint: An example of a `Modelfile` creating a mario blueprint:
```modelfile ```modelfile
FROM llama2 FROM llama3
# sets the temperature to 1 [higher is more creative, lower is more coherent] # sets the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1 PARAMETER temperature 1
# sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token # sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
@ -67,33 +67,25 @@ To use this:
More examples are available in the [examples directory](../examples). More examples are available in the [examples directory](../examples).
### `Modelfile`s in [ollama.com/library][1] To view the Modelfile of a given model, use the `ollama show --modelfile` command.
There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
- Option 1: view a details page from a model's tags page:
1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
2. Click on a tag (e.g. https://ollama.com/library/llama2:13b)
3. Scroll down to "Layers"
- Note: if the [`FROM` instruction](#from-required) is not present,
it means the model was created from a local file
- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:
```bash ```bash
> ollama show --modelfile llama2:13b > ollama show --modelfile llama3
# Modelfile generated by "ollama show" # Modelfile generated by "ollama show"
# To build a new Modelfile based on this one, replace the FROM line with: # To build a new Modelfile based on this one, replace the FROM line with:
# FROM llama2:13b # FROM llama3:latest
FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
FROM /root/.ollama/models/blobs/sha256:123abc {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
{{ end }}{{ .Prompt }} [/INST] """ {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
SYSTEM """"""
PARAMETER stop [INST] {{ .Response }}<|eot_id|>"""
PARAMETER stop [/INST] PARAMETER stop "<|start_header_id|>"
PARAMETER stop <<SYS>> PARAMETER stop "<|end_header_id|>"
PARAMETER stop <</SYS>> PARAMETER stop "<|eot_id|>"
PARAMETER stop "<|reserved_special_token"
``` ```
## Instructions ## Instructions
@ -106,10 +98,10 @@ The `FROM` instruction defines the base model to use when creating a model.
FROM <model name>:<tag> FROM <model name>:<tag>
``` ```
#### Build from llama2 #### Build from llama3
```modelfile ```modelfile
FROM llama2 FROM llama3
``` ```
A list of available base models: A list of available base models:

View file

@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create(
'content': 'Say this is a test', 'content': 'Say this is a test',
} }
], ],
model='llama2', model='llama3',
) )
``` ```
@ -43,7 +43,7 @@ const openai = new OpenAI({
const chatCompletion = await openai.chat.completions.create({ const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }], messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama2', model: 'llama3',
}) })
``` ```
@ -53,7 +53,7 @@ const chatCompletion = await openai.chat.completions.create({
curl http://localhost:11434/v1/chat/completions \ curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d '{
"model": "llama2", "model": "llama3",
"messages": [ "messages": [
{ {
"role": "system", "role": "system",
@ -113,7 +113,7 @@ curl http://localhost:11434/v1/chat/completions \
Before using a model, pull it locally `ollama pull`: Before using a model, pull it locally `ollama pull`:
```shell ```shell
ollama pull llama2 ollama pull llama3
``` ```
### Default model names ### Default model names
@ -121,7 +121,7 @@ ollama pull llama2
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name: For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
``` ```
ollama cp llama2 gpt-3.5-turbo ollama cp llama3 gpt-3.5-turbo
``` ```
Afterwards, this new model name can be specified the `model` field: Afterwards, this new model name can be specified the `model` field:

View file

@ -15,7 +15,7 @@ import { Ollama } from "langchain/llms/ollama";
const ollama = new Ollama({ const ollama = new Ollama({
baseUrl: "http://localhost:11434", baseUrl: "http://localhost:11434",
model: "llama2", model: "llama3",
}); });
const answer = await ollama.invoke(`why is the sky blue?`); const answer = await ollama.invoke(`why is the sky blue?`);
@ -23,7 +23,7 @@ const answer = await ollama.invoke(`why is the sky blue?`);
console.log(answer); console.log(answer);
``` ```
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app. That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
```bash ```bash
npm install cheerio npm install cheerio

View file

@ -12,15 +12,17 @@ So let's figure out how we can use **LangChain** with Ollama to ask our question
Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package: Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
`pip install langchain` `pip install langchain_community`
Then we can create a model and ask the question: Then we can create a model and ask the question:
```python ```python
from langchain.llms import Ollama from langchain_community.llms import Ollama
ollama = Ollama(base_url='http://localhost:11434', ollama = Ollama(
model="llama2") base_url='http://localhost:11434',
print(ollama("why is the sky blue")) model="llama3"
)
print(ollama.invoke("why is the sky blue"))
``` ```
Notice that we are defining the model and the base URL for Ollama. Notice that we are defining the model and the base URL for Ollama.

View file

@ -1,38 +1,15 @@
# Running Ollama on NVIDIA Jetson Devices # Running Ollama on NVIDIA Jetson Devices
With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack). Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.
NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications. The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
mode. This can be verified by using a monitoring tool like jtop.
In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
version of our target model.
Prerequisites:
- curl
- tmux
Here are the steps:
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh` - Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
- Stop the Ollama service: `sudo systemctl stop ollama`
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
- Pull the model you want to use (e.g. mistral): `ollama pull mistral` - Pull the model you want to use (e.g. mistral): `ollama pull mistral`
- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson` - Start an interactive session: `ollama run mistral`
- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
```
FROM mistral
PARAMETER num_gpu 999
```
- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
- Run the new model: `ollama run mistral-jetson`
If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
And that's it! And that's it!
# Running Ollama in Docker
When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).

View file

@ -14,7 +14,7 @@ As this is a preview release, you should expect a few bugs here and there. If
you run into a problem you can reach out on you run into a problem you can reach out on
[Discord](https://discord.gg/ollama), or file an [Discord](https://discord.gg/ollama), or file an
[issue](https://github.com/ollama/ollama/issues). [issue](https://github.com/ollama/ollama/issues).
Logs will often be helpful in dianosing the problem (see Logs will often be helpful in diagnosing the problem (see
[Troubleshooting](#troubleshooting) below) [Troubleshooting](#troubleshooting) below)
## System Requirements ## System Requirements
@ -27,7 +27,7 @@ Logs will often be helpful in dianosing the problem (see
Here's a quick example showing API access from `powershell` Here's a quick example showing API access from `powershell`
```powershell ```powershell
(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json (Invoke-WebRequest -method POST -Body '{"model":"llama3", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
``` ```
## Troubleshooting ## Troubleshooting
@ -45,3 +45,17 @@ the explorer window by hitting `<cmd>+R` and type in:
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH) - `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration - `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories - `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
## Standalone CLI
The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
installer. It installs in your account without requiring Administrator rights.
We update Ollama regularly to support the latest models, and this installer will
help you keep up to date.
If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
and GPU library dependencies for Nvidia and AMD. This allows for embedding
Ollama in existing applications, or running it as a system service via `ollama
serve` with tools such as [NSSM](https://nssm.cc/).

View file

@ -2,7 +2,7 @@
When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other: When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
`ollama run llama2 < sourcequestions.txt` `ollama run llama3 < sourcequestions.txt`
This concept is used in the following example. This concept is used in the following example.

1
examples/flyio/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
fly.toml

67
examples/flyio/README.md Normal file
View file

@ -0,0 +1,67 @@
# Deploy Ollama to Fly.io
> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
## Prerequisites
- Ollama: https://ollama.com/download
- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
## Steps
1. Login to Fly.io
```bash
fly auth login
```
1. Create a new Fly app
```bash
fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
```
1. Pull and run `orca-mini:3b`
```bash
OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
```
`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type or attach a GPU for hardware acceleration (see below).
## (Optional) Persistent Volume
By default Fly Machines use ephemeral storage which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
1. Create the Fly Volume
```bash
fly volume create ollama
```
1. Update `fly.toml` and add `[mounts]`
```toml
[mounts]
source = "ollama"
destination = "/mnt/ollama/models"
```
1. Update `fly.toml` and add `[env]`
```toml
[env]
OLLAMA_MODELS = "/mnt/ollama/models"
```
1. Deploy your app
```bash
fly deploy
```
## (Optional) Hardware Acceleration
Fly.io GPU is currently in waitlist. Sign up for the waitlist: https://fly.io/gpu
Once you've been accepted, create the app with the additional flags `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.

View file

@ -35,7 +35,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
req := &api.ChatRequest{ req := &api.ChatRequest{
Model: "llama2", Model: "llama3",
Messages: messages, Messages: messages,
} }

View file

@ -7,12 +7,24 @@
## Steps ## Steps
1. Create the Ollama namespace, daemon set, and service 1. Create the Ollama namespace, deployment, and service
```bash ```bash
kubectl apply -f cpu.yaml kubectl apply -f cpu.yaml
``` ```
## (Optional) Hardware Acceleration
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin) which is deployed in Kubernetes in form of daemonset. Follow the link for more details.
Once configured, create a GPU enabled Ollama deployment.
```bash
kubectl apply -f gpu.yaml
```
## Test
1. Port forward the Ollama service to connect and use it locally 1. Port forward the Ollama service to connect and use it locally
```bash ```bash
@ -24,13 +36,3 @@
```bash ```bash
ollama run orca-mini:3b ollama run orca-mini:3b
``` ```
## (Optional) Hardware Acceleration
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
Once configured, create a GPU enabled Ollama deployment.
```bash
kubectl apply -f gpu.yaml
```

View file

@ -51,7 +51,7 @@ while True:
template=template, template=template,
) )
llm = Ollama(model="llama2:13b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])) llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
qa_chain = RetrievalQA.from_chain_type( qa_chain = RetrievalQA.from_chain_type(
llm, llm,
retriever=vectorstore.as_retriever(), retriever=vectorstore.as_retriever(),

View file

@ -1,12 +1,12 @@
from langchain.llms import Ollama from langchain_community.llms import Ollama
from langchain.document_loaders import WebBaseLoader from langchain_community.document_loaders import WebBaseLoader
from langchain.chains.summarize import load_summarize_chain from langchain.chains.summarize import load_summarize_chain
loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally") loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
docs = loader.load() docs = loader.load()
llm = Ollama(model="llama2") llm = Ollama(model="llama3")
chain = load_summarize_chain(llm, chain_type="stuff") chain = load_summarize_chain(llm, chain_type="stuff")
result = chain.run(docs) result = chain.invoke(docs)
print(result) print(result)

View file

@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama.
## Running the Example ## Running the Example
1. Ensure you have the `llama2` model installed: 1. Ensure you have the `llama3` model installed:
```bash ```bash
ollama pull llama2 ollama pull llama3
``` ```
2. Install the Python Requirements. 2. Install the Python Requirements.
@ -21,4 +21,3 @@ This example is a basic "hello world" of using LangChain with Ollama.
```bash ```bash
python main.py python main.py
``` ```

View file

@ -1,6 +1,6 @@
from langchain.llms import Ollama from langchain.llms import Ollama
input = input("What is your question?") input = input("What is your question?")
llm = Ollama(model="llama2") llm = Ollama(model="llama3")
res = llm.predict(input) res = llm.predict(input)
print (res) print (res)

View file

@ -1,4 +1,4 @@
FROM llama2 FROM llama3
PARAMETER temperature 1 PARAMETER temperature 1
SYSTEM """ SYSTEM """
You are Mario from super mario bros, acting as an assistant. You are Mario from super mario bros, acting as an assistant.

View file

@ -2,12 +2,12 @@
# Example character: Mario # Example character: Mario
This example shows how to create a basic character using Llama2 as the base model. This example shows how to create a basic character using Llama3 as the base model.
To run this example: To run this example:
1. Download the Modelfile 1. Download the Modelfile
2. `ollama pull llama2` to get the base model used in the model file. 2. `ollama pull llama3` to get the base model used in the model file.
3. `ollama create NAME -f ./Modelfile` 3. `ollama create NAME -f ./Modelfile`
4. `ollama run NAME` 4. `ollama run NAME`
@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?"
What the model file looks like: What the model file looks like:
``` ```
FROM llama2 FROM llama3
PARAMETER temperature 1 PARAMETER temperature 1
SYSTEM """ SYSTEM """
You are Mario from Super Mario Bros, acting as an assistant. You are Mario from Super Mario Bros, acting as an assistant.

View file

@ -2,7 +2,7 @@ import requests
import json import json
import random import random
model = "llama2" model = "llama3"
template = { template = {
"firstName": "", "firstName": "",
"lastName": "", "lastName": "",

View file

@ -12,7 +12,7 @@ countries = [
"France", "France",
] ]
country = random.choice(countries) country = random.choice(countries)
model = "llama2" model = "llama3"
prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters." prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."

View file

@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran
## Running the Example ## Running the Example
1. Ensure you have the `llama2` model installed: 1. Ensure you have the `llama3` model installed:
```bash ```bash
ollama pull llama2 ollama pull llama3
``` ```
2. Install the Python Requirements. 2. Install the Python Requirements.

View file

@ -2,7 +2,7 @@ import json
import requests import requests
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve` # NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
model = "llama2" # TODO: update this for whatever model you wish to use model = "llama3" # TODO: update this for whatever model you wish to use
def chat(messages): def chat(messages):

View file

@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam
## Running the Example ## Running the Example
1. Ensure you have the `llama2` model installed: 1. Ensure you have the `llama3` model installed:
```bash ```bash
ollama pull llama2 ollama pull llama3
``` ```
2. Install the Python Requirements. 2. Install the Python Requirements.

View file

@ -4,10 +4,10 @@ This example demonstrates how one would create a set of 'mentors' you can have a
## Usage ## Usage
1. Add llama2 to have the mentors ask your questions: 1. Add llama3 to have the mentors ask your questions:
```bash ```bash
ollama pull llama2 ollama pull llama3
``` ```
2. Install prerequisites: 2. Install prerequisites:

View file

@ -15,7 +15,7 @@ async function characterGenerator() {
ollama.setModel("stablebeluga2:70b-q4_K_M"); ollama.setModel("stablebeluga2:70b-q4_K_M");
const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `); const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
const thecontents = `FROM llama2\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`; const thecontents = `FROM llama3\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => { fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
if (err) throw err; if (err) throw err;

View file

@ -1,6 +1,6 @@
import * as readline from "readline"; import * as readline from "readline";
const model = "llama2"; const model = "llama3";
type Message = { type Message = {
role: "assistant" | "user" | "system"; role: "assistant" | "user" | "system";
content: string; content: string;

View file

@ -15,6 +15,7 @@ const (
KibiByte = Byte * 1024 KibiByte = Byte * 1024
MebiByte = KibiByte * 1024 MebiByte = KibiByte * 1024
GibiByte = MebiByte * 1024
) )
func HumanBytes(b int64) string { func HumanBytes(b int64) string {
@ -52,6 +53,8 @@ func HumanBytes(b int64) string {
func HumanBytes2(b uint64) string { func HumanBytes2(b uint64) string {
switch { switch {
case b >= GibiByte:
return fmt.Sprintf("%.1f GiB", float64(b)/GibiByte)
case b >= MebiByte: case b >= MebiByte:
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte) return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
case b >= KibiByte: case b >= KibiByte:

View file

@ -13,12 +13,20 @@ const (
func HumanNumber(b uint64) string { func HumanNumber(b uint64) string {
switch { switch {
case b > Billion: case b >= Billion:
return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion)) number := float64(b) / Billion
case b > Million: if number == math.Floor(number) {
return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million)) return fmt.Sprintf("%.0fB", number) // no decimals if whole number
case b > Thousand: }
return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand)) return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
case b >= Million:
number := float64(b) / Million
if number == math.Floor(number) {
return fmt.Sprintf("%.0fM", number) // no decimals if whole number
}
return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
case b >= Thousand:
return fmt.Sprintf("%.0fK", float64(b)/Thousand)
default: default:
return fmt.Sprintf("%d", b) return fmt.Sprintf("%d", b)
} }

34
format/format_test.go Normal file
View file

@ -0,0 +1,34 @@
package format
import (
"testing"
)
func TestHumanNumber(t *testing.T) {
type testCase struct {
input uint64
expected string
}
testCases := []testCase{
{0, "0"},
{1000000, "1M"},
{125000000, "125M"},
{500500000, "500.50M"},
{500550000, "500.55M"},
{1000000000, "1B"},
{2800000000, "2.8B"},
{2850000000, "2.9B"},
{1000000000000, "1000B"},
}
for _, tc := range testCases {
t.Run(tc.expected, func(t *testing.T) {
result := HumanNumber(tc.input)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}

View file

@ -7,7 +7,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "runtime"
"strings" "strings"
) )
@ -35,22 +35,66 @@ func GetSupportedGFX(libDir string) ([]string, error) {
return ret, nil return ret, nil
} }
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) { func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
// Set the visible devices if not already set ids := []string{}
// TODO - does sort order matter? for _, info := range gpuInfo {
devices := []string{} if info.Library != "rocm" {
for i := range ids { // TODO shouldn't happen if things are wired correctly...
if _, skipped := skip[i]; skipped { slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
continue continue
} }
devices = append(devices, strconv.Itoa(i)) ids = append(ids, info.ID)
}
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
} }
val := strings.Join(devices, ",") func commonAMDValidateLibDir() (string, error) {
err := os.Setenv("HIP_VISIBLE_DEVICES", val) // We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
// Scan the LD_LIBRARY_PATH or PATH
pathEnv := "LD_LIBRARY_PATH"
if runtime.GOOS == "windows" {
pathEnv = "PATH"
}
paths := os.Getenv(pathEnv)
for _, path := range filepath.SplitList(paths) {
d, err := filepath.Abs(path)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to set env: %s", err)) continue
} else { }
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val) if rocmLibUsable(d) {
return d, nil
} }
} }
// Well known location(s)
for _, path := range RocmStandardLocations {
if rocmLibUsable(path) {
return path, nil
}
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View file

@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
func (hl *HipLib) Release() { func (hl *HipLib) Release() {
err := windows.FreeLibrary(hl.dll) err := windows.FreeLibrary(hl.dll)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err)) slog.Warn("failed to unload amdhip64.dll", "error", err)
} }
hl.dll = 0 hl.dll = 0
} }
@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
return 0 return 0
} }
if status != hipSuccess { if status != hipSuccess {
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err)) slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
} }
return count return count
} }

View file

@ -11,6 +11,8 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"github.com/ollama/ollama/format"
) )
// Discovery logic for AMD/ROCm GPUs // Discovery logic for AMD/ROCm GPUs
@ -23,26 +25,20 @@ const (
// Prefix with the node dir // Prefix with the node dir
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory" GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
RocmStandardLocation = "/opt/rocm/lib"
// TODO find a better way to detect iGPU instead of minimum memory
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
) )
var ( var (
// Used to validate if the given ROCm lib is usable // Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here... ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
) )
// Gather GPU information from the amdgpu driver if any supported GPUs are detected // Gather GPU information from the amdgpu driver if any supported GPUs are detected
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices func AMDGetGPUInfo() []GpuInfo {
// and the user hasn't already set this variable resp := []GpuInfo{}
func AMDGetGPUInfo(resp *GpuInfo) {
// TODO - DRY this out with windows
if !AMDDetected() { if !AMDDetected() {
return return resp
} }
skip := map[int]interface{}{}
// Opportunistic logging of driver version to aid in troubleshooting // Opportunistic logging of driver version to aid in troubleshooting
ver, err := AMDDriverVersion() ver, err := AMDDriverVersion()
@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
slog.Info("AMD Driver: " + ver) slog.Info("AMD Driver: " + ver)
} else { } else {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err)) slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
} }
// If the user has specified exactly which GPUs to use, look up their memory // Determine if the user has already pre-selected which GPUs to look at, then ignore the others
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES") var visibleDevices []string
if visibleDevices != "" { hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
ids := []int{} rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
for _, idStr := range strings.Split(visibleDevices, ",") { gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
id, err := strconv.Atoi(idStr) switch {
if err != nil { // TODO is this priorty order right?
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err)) case hipVD != "":
} else { visibleDevices = strings.Split(hipVD, ",")
ids = append(ids, id) case rocrVD != "":
visibleDevices = strings.Split(rocrVD, ",")
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
// all our test systems show GPU-XX indicating UUID is not supported
case gpuDO != "":
visibleDevices = strings.Split(gpuDO, ",")
} }
}
amdProcMemLookup(resp, nil, ids)
return
}
// Gather GFX version information from all detected cards
gfx := AMDGFXVersions()
verStrings := []string{}
for i, v := range gfx {
verStrings = append(verStrings, v.ToGFXString())
if v.Major == 0 {
// Silently skip CPUs
skip[i] = struct{}{}
continue
}
if v.Major < 9 {
// TODO consider this a build-time setting if we can support 8xx family GPUs
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
skip[i] = struct{}{}
}
}
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
// Abort if all GPUs are skipped
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
}
updateLibPath(libDir)
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION") gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" { var supported []string
supported, err := GetSupportedGFX(libDir) libDir := ""
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
cpuCount := 0
for _, match := range matches {
slog.Debug("evaluating amdgpu node " + match)
fp, err := os.Open(match)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) slog.Debug("failed to open sysfs node", "file", match, "error", err)
return
}
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
for i, v := range gfx {
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
}
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
ids := make([]int, len(gfx))
i := 0
for k := range gfx {
ids[i] = k
i++
}
amdProcMemLookup(resp, skip, ids)
if resp.memInfo.DeviceCount == 0 {
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
}
func updateLibPath(libDir string) {
ldPaths := []string{}
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
ldPaths = strings.Split(val, ":")
}
for _, d := range ldPaths {
if d == libDir {
return
}
}
val := strings.Join(append(ldPaths, libDir), ":")
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
os.Setenv("LD_LIBRARY_PATH", val)
}
// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
slog.Debug("discovering VRAM for amdgpu devices")
if len(ids) == 0 {
entries, err := os.ReadDir(AMDNodesSysfsDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
return
}
for _, node := range entries {
if !node.IsDir() {
continue continue
} }
id, err := strconv.Atoi(node.Name()) defer fp.Close()
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil { if err != nil {
slog.Warn("malformed amdgpu sysfs node id " + node.Name()) slog.Debug("failed to parse node ID", "error", err)
continue continue
} }
ids = append(ids, id)
}
}
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
for _, id := range ids { scanner := bufio.NewScanner(fp)
if _, skipped := skip[id]; skipped { isCPU := false
var major, minor, patch uint64
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
// Detect CPUs
if len(ver) == 2 && ver[1] == "0" {
slog.Debug("detected CPU " + match)
isCPU = true
break
}
if len(ver) != 2 || len(ver[1]) < 5 {
slog.Warn("malformed "+match, "gfx_target_version", line)
// If this winds up being a CPU, our offsets may be wrong
continue continue
} }
l := len(ver[1])
var err1, err2, err3 error
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
}
// TODO - any other properties we want to extract and record?
// vendor_id + device_id -> pci lookup for "Name"
// Other metrics that may help us understand relative performance between multiple GPUs
}
if isCPU {
cpuCount++
continue
}
// CPUs are always first in the list
gpuID := nodeID - cpuCount
// Shouldn't happen, but just in case...
if gpuID < 0 {
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
return []GpuInfo{}
}
if int(major) < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
continue
}
// Look up the memory for the current node
totalMemory := uint64(0) totalMemory := uint64(0)
usedMemory := uint64(0) usedMemory := uint64(0)
// Adjust for sysfs vs HIP ids propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
propFiles, err := filepath.Glob(propGlob) propFiles, err := filepath.Glob(propGlob)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err)) slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
} }
// 1 or more memory banks - sum the values of all of them // 1 or more memory banks - sum the values of all of them
for _, propFile := range propFiles { for _, propFile := range propFiles {
fp, err := os.Open(propFile) fp, err := os.Open(propFile)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err)) slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
continue continue
} }
defer fp.Close() defer fp.Close()
@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
} }
} }
if totalMemory == 0 { if totalMemory == 0 {
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id)) slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
skip[id] = struct{}{}
continue continue
} }
if totalMemory < IGPUMemLimit { usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
skip[id] = struct{}{}
continue
}
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
usedFiles, err := filepath.Glob(usedGlob) usedFiles, err := filepath.Glob(usedGlob)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err)) slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
continue continue
} }
for _, usedFile := range usedFiles { for _, usedFile := range usedFiles {
fp, err := os.Open(usedFile) fp, err := os.Open(usedFile)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err)) slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
continue continue
} }
defer fp.Close() defer fp.Close()
data, err := io.ReadAll(fp) data, err := io.ReadAll(fp)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err)) slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
continue continue
} }
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err)) slog.Warn("malformed used memory", "data", string(data), "error", err)
continue continue
} }
usedMemory += used usedMemory += used
} }
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024)) // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
resp.memInfo.DeviceCount++ if totalMemory < IGPUMemLimit {
resp.memInfo.TotalMemory += totalMemory slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
resp.memInfo.FreeMemory += (totalMemory - usedMemory) continue
} }
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm" slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: (totalMemory - usedMemory),
},
ID: fmt.Sprintf("%d", gpuID),
// Name: not exposed in sysfs directly, would require pci device id lookup
Major: int(major),
Minor: int(minor),
Patch: int(patch),
MinimumMemory: rocmMinimumMemory,
} }
// If the user wants to filter to a subset of devices, filter out if we aren't a match
if len(visibleDevices) > 0 {
include := false
for _, visible := range visibleDevices {
if visible == gpuInfo.ID {
include = true
break
}
}
if !include {
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
continue
}
}
// Final validation is gfx compatibility - load the library if we haven't already loaded it
// even if the user overrides, we still need to validate the library
if libDir == "" {
libDir, err = AMDValidateLibDir()
if err != nil {
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return []GpuInfo{}
}
}
gpuInfo.DependencyPath = libDir
if gfxOverride == "" {
// Only load supported list once
if len(supported) == 0 {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return []GpuInfo{}
}
slog.Debug("rocm supported GPUs", "types", supported)
}
gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
continue
} else {
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
// The GPU has passed all the verification steps and is supported
resp = append(resp, gpuInfo)
}
if len(resp) == 0 {
slog.Info("no compatible amdgpu devices detected")
}
return resp
} }
// Quick check for AMD driver so we can skip amdgpu discovery if not present // Quick check for AMD driver so we can skip amdgpu discovery if not present
@ -280,87 +297,24 @@ func AMDDetected() bool {
slog.Debug("amdgpu driver not detected " + sysfsDir) slog.Debug("amdgpu driver not detected " + sysfsDir)
return false return false
} else if err != nil { } else if err != nil {
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err)) slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
return false return false
} }
return true return true
} }
func setupLink(source, target string) error {
if err := os.RemoveAll(target); err != nil {
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
}
if err := os.Symlink(source, target); err != nil {
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
}
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
return nil
}
// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements // Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own // failing that, tell the user how to download it on their own
func AMDValidateLibDir() (string, error) { func AMDValidateLibDir() (string, error) {
// We rely on the rpath compiled into our library to find rocm libDir, err := commonAMDValidateLibDir()
// so we establish a symlink to wherever we find it on the system
// to <payloads>/rocm
payloadsDir, err := PayloadsDir()
if err != nil {
return "", err
}
// If we already have a rocm dependency wired, nothing more to do
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
if rocmLibUsable(rocmTargetDir) {
return rocmTargetDir, nil
}
// next to the running binary
exe, err := os.Executable()
if err == nil { if err == nil {
peerDir := filepath.Dir(exe) return libDir, nil
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
} }
// Well known ollama installer path // Well known ollama installer path
installedRocmDir := "/usr/share/ollama/lib/rocm" installedRocmDir := "/usr/share/ollama/lib/rocm"
if rocmLibUsable(installedRocmDir) { if rocmLibUsable(installedRocmDir) {
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir) return installedRocmDir, nil
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "lib")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
}
}
// Scan the library path for potential matches
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
if rocmLibUsable(d) {
return rocmTargetDir, setupLink(d, rocmTargetDir)
}
}
// Well known location(s)
if rocmLibUsable("/opt/rocm/lib") {
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
} }
// If we still haven't found a usable rocm, the user will have to install it on their own // If we still haven't found a usable rocm, the user will have to install it on their own
@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
} }
return strings.TrimSpace(string(verString)), nil return strings.TrimSpace(string(verString)), nil
} }
func AMDGFXVersions() map[int]Version {
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
res := map[int]Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
continue
}
if i == 0 {
// Skipping the CPU
continue
}
// Align with HIP IDs (zero is first GPU, not CPU)
i -= 1
scanner := bufio.NewScanner(fp)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
if ver[1] != "0" {
slog.Debug("malformed " + line)
}
res[i] = Version{
Major: 0,
Minor: 0,
Patch: 0,
}
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res[i] = Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
}
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

View file

@ -7,11 +7,13 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
"strings" "strings"
"github.com/ollama/ollama/format"
) )
const ( const (
RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob?
// TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true // TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
iGPUName = "AMD Radeon(TM) Graphics" iGPUName = "AMD Radeon(TM) Graphics"
@ -20,38 +22,35 @@ const (
var ( var (
// Used to validate if the given ROCm lib is usable // Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here... ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
) )
func AMDGetGPUInfo(resp *GpuInfo) { func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
hl, err := NewHipLib() hl, err := NewHipLib()
if err != nil { if err != nil {
slog.Debug(err.Error()) slog.Debug(err.Error())
return return nil
} }
defer hl.Release() defer hl.Release()
skip := map[int]interface{}{}
ids := []int{}
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
ver, err := hl.AMDDriverVersion() ver, err := hl.AMDDriverVersion()
if err == nil { if err == nil {
slog.Info("AMD Driver: " + ver) slog.Info("AMD Driver: " + ver)
} else { } else {
// For now this is benign, but we may eventually need to fail compatibility checks // For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err)) slog.Debug("error looking up amd driver version", "error", err)
} }
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES // Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount() count := hl.HipGetDeviceCount()
if count == 0 { if count == 0 {
return return nil
} }
libDir, err := AMDValidateLibDir() libDir, err := AMDValidateLibDir()
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err)) slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return return nil
} }
var supported []string var supported []string
@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) {
if gfxOverride == "" { if gfxOverride == "" {
supported, err = GetSupportedGFX(libDir) supported, err = GetSupportedGFX(libDir)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return return nil
} }
} else { } else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
} }
slog.Info(fmt.Sprintf("detected %d hip devices", count)) slog.Info("detected hip devices", "count", count)
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
ids = append(ids, i)
err = hl.HipSetDevice(i) err = hl.HipSetDevice(i)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err)) slog.Warn("set device", "id", i, "error", err)
skip[i] = struct{}{}
continue continue
} }
props, err := hl.HipGetDeviceProperties(i) props, err := hl.HipGetDeviceProperties(i)
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err)) slog.Warn("get properties", "id", i, "error", err)
skip[i] = struct{}{}
continue continue
} }
n := bytes.IndexByte(props.Name[:], 0) n := bytes.IndexByte(props.Name[:], 0)
name := string(props.Name[:n]) name := string(props.Name[:n])
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name)) // TODO is UUID actually populated on windows?
// Can luid be used on windows for setting visible devices (and is it actually set?)
n = bytes.IndexByte(props.GcnArchName[:], 0) n = bytes.IndexByte(props.GcnArchName[:], 0)
gfx := string(props.GcnArchName[:n]) gfx := string(props.GcnArchName[:n])
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx)) slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
var major, minor, patch string
switch len(gfx) {
case 6:
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
case 7:
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
}
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 //slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
// TODO Why isn't props.iGPU accurate!? // TODO Why isn't props.iGPU accurate!?
if strings.EqualFold(name, iGPUName) { if strings.EqualFold(name, iGPUName) {
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i)) slog.Info("iGPU detected skipping", "id", i)
skip[i] = struct{}{}
continue continue
} }
if gfxOverride == "" { if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) { if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported)) slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting? // TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
continue continue
} else { } else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx)) slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
} }
} }
totalMemory, freeMemory, err := hl.HipMemGetInfo() freeMemory, totalMemory, err := hl.HipMemGetInfo()
if err != nil { if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err)) slog.Warn("get mem info", "id", i, "error", err)
continue continue
} }
// TODO according to docs, freeMem may lie on windows! // iGPU detection, remove this check once we can support an iGPU variant of the rocm library
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory)) if totalMemory < IGPUMemLimit {
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory)) slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
resp.memInfo.DeviceCount++ continue
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += freeMemory
} }
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm" // TODO revisit this once ROCm v6 is available on windows.
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: freeMemory,
},
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
DependencyPath: libDir,
MinimumMemory: rocmMinimumMemory,
} }
// Abort if all GPUs are skipped if major != "" {
if len(skip) >= count { gpuInfo.Major, err = strconv.Atoi(major)
slog.Info("all detected amdgpus are skipped, falling back to CPU") if err != nil {
return slog.Info("failed to parse version", "version", gfx, "error", err)
} }
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
} }
UpdatePath(libDir) if minor != "" {
gpuInfo.Minor, err = strconv.Atoi(minor)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if patch != "" {
// Patch rev is hex; e.g. gfx90a
p, err := strconv.ParseInt(patch, 16, 0)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
} else {
gpuInfo.Patch = int(p)
}
}
if gpuInfo.Major < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
continue
}
resp = append(resp, gpuInfo)
}
return resp
} }
func AMDValidateLibDir() (string, error) { func AMDValidateLibDir() (string, error) {
// On windows non-admins typically can't create links libDir, err := commonAMDValidateLibDir()
// so instead of trying to rely on rpath and a link in
// $LibDir/rocm, we instead rely on setting PATH to point
// to the location of the ROCm library
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil { if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm") return libDir, nil
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
} }
// Installer payload (if we're running from some other location) // Installer payload (if we're running from some other location)
@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) {
return rocmTargetDir, nil return rocmTargetDir, nil
} }
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm") slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
return "", fmt.Errorf("no suitable rocm found, falling back to CPU") return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

View file

@ -12,6 +12,8 @@ import (
"sync" "sync"
"syscall" "syscall"
"time" "time"
"github.com/ollama/ollama/server/envconfig"
) )
var ( var (
@ -24,8 +26,16 @@ func PayloadsDir() (string, error) {
defer lock.Unlock() defer lock.Unlock()
var err error var err error
if payloadsDir == "" { if payloadsDir == "" {
runnersDir := envconfig.RunnersDir
if runnersDir != "" {
payloadsDir = runnersDir
return payloadsDir, nil
}
// The remainder only applies on non-windows where we still carry payloads in the main executable
cleanupTmpDirs() cleanupTmpDirs()
tmpDir := os.Getenv("OLLAMA_TMPDIR") tmpDir := envconfig.TmpDir
if tmpDir == "" { if tmpDir == "" {
tmpDir, err = os.MkdirTemp("", "ollama") tmpDir, err = os.MkdirTemp("", "ollama")
if err != nil { if err != nil {
@ -80,7 +90,7 @@ func cleanupTmpDirs() {
} }
err = os.RemoveAll(d) err = os.RemoveAll(d)
if err != nil { if err != nil {
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err)) slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
} }
} }
} }
@ -88,7 +98,8 @@ func cleanupTmpDirs() {
func Cleanup() { func Cleanup() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
if payloadsDir != "" { runnersDir := envconfig.RunnersDir
if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
// We want to fully clean up the tmpdir parent of the payloads dir // We want to fully clean up the tmpdir parent of the payloads dir
tmpDir := filepath.Clean(filepath.Join(payloadsDir, "..")) tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
slog.Debug("cleaning up", "dir", tmpDir) slog.Debug("cleaning up", "dir", tmpDir)
@ -120,7 +131,7 @@ func UpdatePath(dir string) {
} }
} }
newPath := strings.Join(append([]string{dir}, pathComponents...), ";") newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath)) slog.Info("updating", "PATH", newPath)
os.Setenv("PATH", newPath) os.Setenv("PATH", newPath)
} }
// linux and darwin rely on rpath // linux and darwin rely on rpath

22
gpu/cuda_common.go Normal file
View file

@ -0,0 +1,22 @@
//go:build linux || windows
package gpu
import (
"log/slog"
"strings"
)
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "cuda" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View file

@ -16,22 +16,23 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/server/envconfig"
) )
type handles struct { type handles struct {
nvml *C.nvml_handle_t deviceCount int
cudart *C.cudart_handle_t cudart *C.cudart_handle_t
nvcuda *C.nvcuda_handle_t
} }
const ( const (
cudaMinimumMemory = 457 * format.MebiByte cudaMinimumMemory = 256 * format.MebiByte
rocmMinimumMemory = 457 * format.MebiByte rocmMinimumMemory = 256 * format.MebiByte
) )
var gpuMutex sync.Mutex var gpuMutex sync.Mutex
@ -39,26 +40,10 @@ var gpuMutex sync.Mutex
// With our current CUDA compile flags, older than 5.0 will not work properly // With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0} var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library var RocmComputeMin = 9
var NvmlLinuxGlobs = []string{
"/usr/local/cuda/lib64/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
"/usr/lib/wsl/lib/libnvidia-ml.so*",
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
"/opt/cuda/lib64/libnvidia-ml.so*",
"/usr/lib*/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
"/usr/local/lib*/libnvidia-ml.so*",
// TODO: are these stubs ever valid? // TODO find a better way to detect iGPU instead of minimum memory
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*", const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
}
var NvmlWindowsGlobs = []string{
"c:\\Windows\\System32\\nvml.dll",
}
var CudartLinuxGlobs = []string{ var CudartLinuxGlobs = []string{
"/usr/local/cuda/lib64/libcudart.so*", "/usr/local/cuda/lib64/libcudart.so*",
@ -79,6 +64,22 @@ var CudartWindowsGlobs = []string{
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll", "c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
} }
var NvcudaLinuxGlobs = []string{
"/usr/local/cuda*/targets/*/lib/libcuda.so*",
"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
"/usr/lib/*-linux-gnu/libcuda.so*",
"/usr/lib/wsl/lib/libcuda.so*",
"/usr/lib/wsl/drivers/*/libcuda.so*",
"/opt/cuda/lib*/libcuda.so*",
"/usr/local/cuda/lib*/libcuda.so*",
"/usr/lib*/libcuda.so*",
"/usr/local/lib*/libcuda.so*",
}
var NvcudaWindowsGlobs = []string{
"c:\\windows\\system*\\nvcuda.dll",
}
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK") var CudaTegra string = os.Getenv("JETSON_JETPACK")
@ -88,61 +89,62 @@ func initGPUHandles() *handles {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles := &handles{nil, nil} gpuHandles := &handles{}
var nvmlMgmtName string
var nvmlMgmtPatterns []string
var cudartMgmtName string var cudartMgmtName string
var cudartMgmtPatterns []string var cudartMgmtPatterns []string
var nvcudaMgmtName string
var nvcudaMgmtPatterns []string
tmpDir, _ := PayloadsDir() tmpDir, _ := PayloadsDir()
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
nvmlMgmtName = "nvml.dll"
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
cudartMgmtName = "cudart64_*.dll" cudartMgmtName = "cudart64_*.dll"
localAppData := os.Getenv("LOCALAPPDATA") localAppData := os.Getenv("LOCALAPPDATA")
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)} cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...) cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "nvcuda.dll"
nvcudaMgmtPatterns = NvcudaWindowsGlobs
case "linux": case "linux":
nvmlMgmtName = "libnvidia-ml.so"
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
cudartMgmtName = "libcudart.so*" cudartMgmtName = "libcudart.so*"
if tmpDir != "" { if tmpDir != "" {
// TODO - add "payloads" for subprocess // TODO - add "payloads" for subprocess
cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)} cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
} }
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...) cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "libcuda.so*"
nvcudaMgmtPatterns = NvcudaLinuxGlobs
default: default:
return gpuHandles return gpuHandles
} }
slog.Info("Detecting GPU type") slog.Info("Detecting GPUs")
nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
if len(nvcudaLibPaths) > 0 {
deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
if nvcuda != nil {
slog.Info("detected GPUs", "count", deviceCount, "library", libPath)
gpuHandles.nvcuda = nvcuda
gpuHandles.deviceCount = deviceCount
return gpuHandles
}
}
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns) cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
if len(cudartLibPaths) > 0 { if len(cudartLibPaths) > 0 {
cudart := LoadCUDARTMgmt(cudartLibPaths) deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil { if cudart != nil {
slog.Info("Nvidia GPU detected via cudart") slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
gpuHandles.cudart = cudart gpuHandles.cudart = cudart
return gpuHandles gpuHandles.deviceCount = deviceCount
}
}
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
if len(nvmlLibPaths) > 0 {
nvml := LoadNVMLMgmt(nvmlLibPaths)
if nvml != nil {
slog.Info("Nvidia GPU detected via nvidia-ml")
gpuHandles.nvml = nvml
return gpuHandles return gpuHandles
} }
} }
return gpuHandles return gpuHandles
} }
func GetGPUInfo() GpuInfo { func GetGPUInfo() GpuInfoList {
// TODO - consider exploring lspci (and equivalent on windows) to check for // TODO - consider exploring lspci (and equivalent on windows) to check for
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
gpuMutex.Lock() gpuMutex.Lock()
@ -150,12 +152,12 @@ func GetGPUInfo() GpuInfo {
gpuHandles := initGPUHandles() gpuHandles := initGPUHandles()
defer func() { defer func() {
if gpuHandles.nvml != nil {
C.nvml_release(*gpuHandles.nvml)
}
if gpuHandles.cudart != nil { if gpuHandles.cudart != nil {
C.cudart_release(*gpuHandles.cudart) C.cudart_release(*gpuHandles.cudart)
} }
if gpuHandles.nvcuda != nil {
C.nvcuda_release(*gpuHandles.nvcuda)
}
}() }()
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX // All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
@ -164,73 +166,75 @@ func GetGPUInfo() GpuInfo {
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.") slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
} }
// On windows we bundle the nvidia library one level above the runner dir
depPath := ""
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
depPath = filepath.Dir(envconfig.RunnersDir)
}
var memInfo C.mem_info_t var memInfo C.mem_info_t
resp := GpuInfo{} resp := []GpuInfo{}
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.nvml_check_vram(*gpuHandles.nvml, &memInfo) // NVIDIA first
for i := 0; i < gpuHandles.deviceCount; i++ {
// TODO once we support CPU compilation variants of GPU libraries refine this...
if cpuVariant == "" && runtime.GOARCH == "amd64" {
continue
}
gpuInfo := GpuInfo{
Library: "cuda",
}
if gpuHandles.cudart != nil {
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
} else {
C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
}
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err))) slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err)) C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 { continue
// Verify minimum compute capability
var cc C.nvml_compute_capability_t
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
} }
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
continue
} }
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") { gpuInfo.TotalMemory = uint64(memInfo.total)
C.cudart_check_vram(*gpuHandles.cudart, &memInfo) gpuInfo.FreeMemory = uint64(memInfo.free)
if memInfo.err != nil { gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err))) gpuInfo.Major = int(memInfo.major)
C.free(unsafe.Pointer(memInfo.err)) gpuInfo.Minor = int(memInfo.minor)
} else if memInfo.count > 0 { gpuInfo.MinimumMemory = cudaMinimumMemory
// Verify minimum compute capability gpuInfo.DependencyPath = depPath
var cc C.cudart_compute_capability_t
C.cudart_compute_capability(*gpuHandles.cudart, &cc) // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
if cc.err != nil { resp = append(resp, gpuInfo)
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
} }
}
} else { // Then AMD
AMDGetGPUInfo(&resp) resp = append(resp, AMDGetGPUInfo()...)
if resp.Library != "" {
resp.MinimumMemory = rocmMinimumMemory if len(resp) == 0 {
return resp
}
}
if resp.Library == "" {
C.cpu_check_ram(&memInfo) C.cpu_check_ram(&memInfo)
resp.Library = "cpu"
resp.Variant = cpuVariant
}
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err))) slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err)) C.free(unsafe.Pointer(memInfo.err))
return resp return resp
} }
gpuInfo := GpuInfo{
Library: "cpu",
Variant: cpuVariant,
}
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
resp = append(resp, gpuInfo)
}
resp.DeviceCount = uint32(memInfo.count)
resp.FreeMemory = uint64(memInfo.free)
resp.TotalMemory = uint64(memInfo.total)
return resp return resp
} }
func getCPUMem() (memInfo, error) { func GetCPUMem() (memInfo, error) {
var ret memInfo var ret memInfo
var info C.mem_info_t var info C.mem_info_t
C.cpu_check_ram(&info) C.cpu_check_ram(&info)
@ -243,29 +247,12 @@ func getCPUMem() (memInfo, error) {
return ret, nil return ret, nil
} }
func CheckVRAM() (uint64, error) { func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return gpuInfo.FreeMemory, nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
}
func FindGPULibs(baseLibName string, patterns []string) []string {
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
var ldPaths []string var ldPaths []string
var patterns []string
gpuLibPaths := []string{} gpuLibPaths := []string{}
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName)) slog.Debug("Searching for GPU library", "name", baseLibName)
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
@ -283,8 +270,14 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
} }
patterns = append(patterns, filepath.Join(d, baseLibName+"*")) patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
} }
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns)) patterns = append(patterns, defaultPatterns...)
slog.Debug("gpu library search", "globs", patterns)
for _, pattern := range patterns { for _, pattern := range patterns {
// Nvidia PhysX known to return bogus results
if strings.Contains(pattern, "PhysX") {
slog.Debug("skipping PhysX cuda library path", "path", pattern)
}
// Ignore glob discovery errors // Ignore glob discovery errors
matches, _ := filepath.Glob(pattern) matches, _ := filepath.Glob(pattern)
for _, match := range matches { for _, match := range matches {
@ -311,28 +304,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
} }
} }
} }
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths)) slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
return gpuLibPaths return gpuLibPaths
} }
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t { func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
var resp C.nvml_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range nvmlLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.nvml_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
}
}
return nil
}
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
var resp C.cudart_init_resp_t var resp C.cudart_init_resp_t
resp.ch.verbose = getVerboseState() resp.ch.verbose = getVerboseState()
for _, libPath := range cudartLibPaths { for _, libPath := range cudartLibPaths {
@ -340,18 +316,54 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
defer C.free(unsafe.Pointer(lib)) defer C.free(unsafe.Pointer(lib))
C.cudart_init(lib, &resp) C.cudart_init(lib, &resp)
if resp.err != nil { if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err))) slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err)) C.free(unsafe.Pointer(resp.err))
} else { } else {
return &resp.ch return int(resp.num_devices), &resp.ch, libPath
} }
} }
return nil return 0, nil, ""
}
func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
var resp C.nvcuda_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range nvcudaLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.nvcuda_init(lib, &resp)
if resp.err != nil {
slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return int(resp.num_devices), &resp.ch, libPath
}
}
return 0, nil, ""
} }
func getVerboseState() C.uint16_t { func getVerboseState() C.uint16_t {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { if envconfig.Debug {
return C.uint16_t(1) return C.uint16_t(1)
} }
return C.uint16_t(0) return C.uint16_t(0)
} }
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variable
//
// If different libraries are detected, the first one is what we use
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
if len(l) == 0 {
return "", ""
}
switch l[0].Library {
case "cuda":
return cudaGetVisibleDevicesEnv(l)
case "rocm":
return rocmGetVisibleDevicesEnv(l)
default:
slog.Debug("no filter required for library " + l[0].Library)
return "", ""
}
}

View file

@ -9,52 +9,47 @@ package gpu
*/ */
import "C" import "C"
import ( import (
"fmt"
"log/slog"
"os"
"runtime" "runtime"
"strconv"
"github.com/ollama/ollama/format"
) )
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs const (
func CheckVRAM() (uint64, error) { metalMinimumMemory = 384 * format.MebiByte
userLimit := os.Getenv("OLLAMA_MAX_VRAM") )
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
func GetGPUInfo() GpuInfoList {
mem, _ := GetCPUMem()
if runtime.GOARCH == "amd64" { if runtime.GOARCH == "amd64" {
// gpu not supported, this may not be metal return []GpuInfo{
return 0, nil {
}
return uint64(C.getRecommendedMaxVRAM()), nil
}
func GetGPUInfo() GpuInfo {
mem, _ := getCPUMem()
if runtime.GOARCH == "amd64" {
return GpuInfo{
Library: "cpu", Library: "cpu",
Variant: GetCPUVariant(), Variant: GetCPUVariant(),
memInfo: mem, memInfo: mem,
},
} }
} }
return GpuInfo{ info := GpuInfo{
Library: "metal", Library: "metal",
memInfo: mem, ID: "0",
} }
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
info.FreeMemory = info.TotalMemory
info.MinimumMemory = metalMinimumMemory
return []GpuInfo{info}
} }
func getCPUMem() (memInfo, error) { func GetCPUMem() (memInfo, error) {
return memInfo{ return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()), TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: 0, FreeMemory: 0,
DeviceCount: 1,
}, nil }, nil
} }
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
// No-op on darwin
return "", ""
}

View file

@ -38,12 +38,17 @@
extern "C" { extern "C" {
#endif #endif
#define GPU_ID_LEN 64
typedef struct mem_info { typedef struct mem_info {
char *err; // If non-nill, caller responsible for freeing
char gpu_id[GPU_ID_LEN];
uint64_t total; uint64_t total;
uint64_t free; uint64_t free;
unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore // Compute Capability
char *err; // If non-nill, caller responsible for freeing int major;
int minor;
} mem_info_t; } mem_info_t;
void cpu_check_ram(mem_info_t *resp); void cpu_check_ram(mem_info_t *resp);
@ -52,8 +57,8 @@ void cpu_check_ram(mem_info_t *resp);
} }
#endif #endif
#include "gpu_info_nvml.h"
#include "gpu_info_cudart.h" #include "gpu_info_cudart.h"
#include "gpu_info_nvcuda.h"
#endif // __GPU_INFO_H__ #endif // __GPU_INFO_H__
#endif // __APPLE__ #endif // __APPLE__

View file

@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
MEMORYSTATUSEX info; MEMORYSTATUSEX info;
info.dwLength = sizeof(info); info.dwLength = sizeof(info);
if (GlobalMemoryStatusEx(&info) != 0) { if (GlobalMemoryStatusEx(&info) != 0) {
resp->count = 1;
resp->total = info.ullTotalPhys; resp->total = info.ullTotalPhys;
resp->free = info.ullAvailPhys; resp->free = info.ullAvailPhys;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
} else { } else {
resp->err = LOAD_ERR(); resp->err = LOAD_ERR();
} }
@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
if (sysinfo(&info) != 0) { if (sysinfo(&info) != 0) {
resp->err = strdup(strerror(errno)); resp->err = strdup(strerror(errno));
} else { } else {
resp->count = 1;
resp->total = info.totalram * info.mem_unit; resp->total = info.totalram * info.mem_unit;
resp->free = info.freeram * info.mem_unit; resp->free = info.freeram * info.mem_unit;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
} }
return; return;
} }

View file

@ -6,6 +6,7 @@
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
cudartReturn_t ret; cudartReturn_t ret;
resp->err = NULL; resp->err = NULL;
resp->num_devices = 0;
const int buflen = 256; const int buflen = 256;
char buf[buflen + 1]; char buf[buflen + 1];
int i; int i;
@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount}, {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute}, {"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion}, {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
{NULL, NULL}, {NULL, NULL},
}; };
@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
return; return;
} }
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
for (i = 0; l[i].s != NULL; i++) { for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) { if (!l[i].p) {
char *msg = LOAD_ERR(); char *msg = LOAD_ERR();
@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
UNLOAD_LIBRARY(resp->ch.handle); UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL; resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) { if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama"); resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return; return;
} }
snprintf(buf, buflen, "cudart init failure: %d", ret); snprintf(buf, buflen, "cudart init failure: %d", ret);
@ -85,42 +81,82 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10; driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor); LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
} }
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
if (ret != CUDART_SUCCESS) {
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
} }
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) { void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
resp->err = NULL; resp->err = NULL;
cudartMemory_t memInfo = {0,0,0}; cudartMemory_t memInfo = {0,0,0};
cudartReturn_t ret; cudartReturn_t ret;
const int buflen = 256; const int buflen = 256;
char buf[buflen + 1]; char buf[buflen + 1];
int i;
if (h.handle == NULL) { if (h.handle == NULL) {
resp->err = strdup("cudart handle isn't initialized"); resp->err = strdup("cudart handle isn't initialized");
return; return;
} }
// cudaGetDeviceCount takes int type, resp-> count is uint
int deviceCount;
ret = (*h.cudaGetDeviceCount)(&deviceCount);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
} else {
resp->count = (unsigned int)deviceCount;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp-> count; i++) {
ret = (*h.cudaSetDevice)(i); ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) { if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize"); snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf); resp->err = strdup(buf);
return; return;
} }
cudaDeviceProp_t props;
ret = (*h.cudaGetDeviceProperties)(&props, i);
if (ret != CUDART_SUCCESS) {
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
resp->major = 0;
resp->minor = 0;
} else {
int allNull = 1;
for (int j = 0; j < 16; j++) {
if (props.uuid.bytes[j] != 0) {
allNull = 0;
break;
}
}
if (allNull != 0) {
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
} else {
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
props.uuid.bytes[0],
props.uuid.bytes[1],
props.uuid.bytes[2],
props.uuid.bytes[3],
props.uuid.bytes[4],
props.uuid.bytes[5],
props.uuid.bytes[6],
props.uuid.bytes[7],
props.uuid.bytes[8],
props.uuid.bytes[9],
props.uuid.bytes[10],
props.uuid.bytes[11],
props.uuid.bytes[12],
props.uuid.bytes[13],
props.uuid.bytes[14],
props.uuid.bytes[15]
);
}
resp->major = props.major;
resp->minor = props.minor;
// TODO add other useful properties from props
}
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
if (ret != CUDART_SUCCESS) { if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
@ -128,67 +164,12 @@ void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
return; return;
} }
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total); resp->total = memInfo.total;
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free); resp->free = memInfo.free;
resp->total += memInfo.total; LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
resp->free += memInfo.free; LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
} LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
}
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
int major = 0;
int minor = 0;
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle not initialized");
return;
}
int devices;
ret = (*h.cudaGetDeviceCount)(&devices);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
} }
void cudart_release(cudart_handle_t h) { void cudart_release(cudart_handle_t h) {

View file

@ -6,14 +6,20 @@
// Just enough typedef's to dlopen/dlsym for memory information // Just enough typedef's to dlopen/dlsym for memory information
typedef enum cudartReturn_enum { typedef enum cudartReturn_enum {
CUDART_SUCCESS = 0, CUDART_SUCCESS = 0,
CUDART_UNSUPPORTED = 1, CUDART_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35, CUDART_ERROR_MEMORY_ALLOCATION = 2,
CUDART_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now... // Other values omitted for now...
} cudartReturn_t; } cudartReturn_t;
typedef enum cudartDeviceAttr_enum { typedef enum cudartDeviceAttr_enum {
cudartDevAttrComputeCapabilityMajor = 75, cudartDevAttrComputeCapabilityMajor = 75,
cudartDevAttrComputeCapabilityMinor = 76, cudartDevAttrComputeCapabilityMinor = 76,
// TODO - not yet wired up but may be useful for Jetson or other
// integrated GPU scenarios with shared memory
cudaDevAttrIntegrated = 18
} cudartDeviceAttr_t; } cudartDeviceAttr_t;
typedef void *cudartDevice_t; // Opaque is sufficient typedef void *cudartDevice_t; // Opaque is sufficient
@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
int minor; int minor;
} cudartDriverVersion_t; } cudartDriverVersion_t;
typedef struct cudaUUID {
unsigned char bytes[16];
} cudaUUID_t;
typedef struct cudaDeviceProp {
char name[256]; /**< ASCII string identifying device */
cudaUUID_t uuid; /**< 16-byte unique identifier */
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
size_t totalGlobalMem; /**< Global memory available on device in bytes */
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
int regsPerBlock; /**< 32-bit registers available per block */
int warpSize; /**< Warp size in threads */
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
int maxThreadsPerBlock; /**< Maximum number of threads per block */
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
int clockRate; /**< Clock frequency in kilohertz */
size_t totalConstMem; /**< Constant memory available on device in bytes */
int major; /**< Major compute capability */
int minor; /**< Minor compute capability */
size_t textureAlignment; /**< Alignment requirement for textures */
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
int multiProcessorCount; /**< Number of multiprocessors on device */
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
int integrated; /**< Device is integrated as opposed to discrete */
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
int maxTexture1D; /**< Maximum 1D texture size */
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
int maxSurface1D; /**< Maximum 1D surface size */
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
int ECCEnabled; /**< Device has ECC support enabled */
int pciBusID; /**< PCI bus ID of the device */
int pciDeviceID; /**< PCI device ID of the device */
int pciDomainID; /**< PCI domain ID of the device */
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
int asyncEngineCount; /**< Number of asynchronous engines */
int unifiedAddressing; /**< Device shares a unified address space with the host */
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
int memoryBusWidth; /**< Global memory bus width in bits */
int l2CacheSize; /**< Size of L2 cache in bytes */
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
int streamPrioritiesSupported; /**< Device supports stream priorities */
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
int localL1CacheSupported; /**< Device supports caching locals in L1 */
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
int managedMemory; /**< Device supports allocating managed memory on this system */
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
int computePreemptionSupported; /**< Device supports Compute Preemption */
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
} cudaDeviceProp_t;
typedef struct cudart_handle { typedef struct cudart_handle {
void *handle; void *handle;
uint16_t verbose; uint16_t verbose;
@ -38,23 +130,17 @@ typedef struct cudart_handle {
cudartReturn_t (*cudaGetDeviceCount)(int *); cudartReturn_t (*cudaGetDeviceCount)(int *);
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device); cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion); cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
} cudart_handle_t; } cudart_handle_t;
typedef struct cudart_init_resp { typedef struct cudart_init_resp {
char *err; // If err is non-null handle is invalid char *err; // If err is non-null handle is invalid
cudart_handle_t ch; cudart_handle_t ch;
int num_devices;
} cudart_init_resp_t; } cudart_init_resp_t;
typedef struct cudart_compute_capability {
char *err;
int major;
int minor;
} cudart_compute_capability_t;
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp); void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp); void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
void cudart_release(cudart_handle_t ch); void cudart_release(cudart_handle_t ch);
#endif // __GPU_INFO_CUDART_H__ #endif // __GPU_INFO_CUDART_H__

203
gpu/gpu_info_nvcuda.c Normal file
View file

@ -0,0 +1,203 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include "gpu_info_nvcuda.h"
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
CUresult ret;
resp->err = NULL;
resp->num_devices = 0;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"cuInit", (void *)&resp->ch.cuInit},
{"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion},
{"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount},
{"cuDeviceGet", (void *)&resp->ch.cuDeviceGet},
{"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute},
{"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid},
{"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3},
{"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2},
{"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy},
{NULL, NULL},
};
resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY);
if (!resp->ch.handle) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
nvcuda_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
for (i = 0; l[i].s != NULL; i++) {
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!*l[i].p) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->ch.cuInit)(0);
if (ret != CUDA_SUCCESS) {
LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
snprintf(buf, buflen, "nvcuda init failure: %d", ret);
resp->err = strdup(buf);
return;
}
int version = 0;
nvcudaDriverVersion_t driverVersion;
driverVersion.major = 0;
driverVersion.minor = 0;
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.cuDriverGetVersion)(&version);
if (ret != CUDA_SUCCESS) {
LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret);
} else {
driverVersion.major = version / 1000;
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
}
ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices);
if (ret != CUDA_SUCCESS) {
LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
}
const int buflen = 256;
void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
resp->err = NULL;
nvcudaMemory_t memInfo = {0,0};
CUresult ret;
CUdevice device = -1;
CUcontext ctx = NULL;
char buf[buflen + 1];
CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
if (h.handle == NULL) {
resp->err = strdup("nvcuda handle isn't initialized");
return;
}
ret = (*h.cuDeviceGet)(&device, i);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "nvcuda device failed to initialize");
resp->err = strdup(buf);
return;
}
resp->major = 0;
resp->minor = 0;
int major = 0;
int minor = 0;
ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
if (ret != CUDA_SUCCESS) {
LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret);
} else {
ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
if (ret != CUDA_SUCCESS) {
LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret);
} else {
resp->minor = minor;
resp->major = major;
}
}
ret = (*h.cuDeviceGetUuid)(&uuid, device);
if (ret != CUDA_SUCCESS) {
LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret);
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
} else {
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
uuid.bytes[0],
uuid.bytes[1],
uuid.bytes[2],
uuid.bytes[3],
uuid.bytes[4],
uuid.bytes[5],
uuid.bytes[6],
uuid.bytes[7],
uuid.bytes[8],
uuid.bytes[9],
uuid.bytes[10],
uuid.bytes[11],
uuid.bytes[12],
uuid.bytes[13],
uuid.bytes[14],
uuid.bytes[15]
);
}
// To get memory we have to set (and release) a context
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
resp->err = strdup(buf);
// Best effort on failure...
(*h.cuCtxDestroy)(ctx);
return;
}
resp->total = memInfo.total;
resp->free = memInfo.free;
LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
ret = (*h.cuCtxDestroy)(ctx);
if (ret != CUDA_SUCCESS) {
LOG(1, "nvcuda failed to release primary device context %d", ret);
}
}
void nvcuda_release(nvcuda_handle_t h) {
LOG(h.verbose, "releasing nvcuda library\n");
UNLOAD_LIBRARY(h.handle);
// TODO and other context release logic?
h.handle = NULL;
}
#endif // __APPLE__

71
gpu/gpu_info_nvcuda.h Normal file
View file

@ -0,0 +1,71 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_NVCUDA_H__
#define __GPU_INFO_NVCUDA_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum cudaError_enum {
CUDA_SUCCESS = 0,
CUDA_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_MEMORY_ALLOCATION = 2,
CUDA_ERROR_NOT_INITIALIZED = 3,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now...
} CUresult;
typedef enum CUdevice_attribute_enum {
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,
// TODO - not yet wired up but may be useful for Jetson or other
// integrated GPU scenarios with shared memory
CU_DEVICE_ATTRIBUTE_INTEGRATED = 18
} CUdevice_attribute;
typedef void *nvcudaDevice_t; // Opaque is sufficient
typedef struct nvcudaMemory_st {
uint64_t total;
uint64_t free;
} nvcudaMemory_t;
typedef struct nvcudaDriverVersion {
int major;
int minor;
} nvcudaDriverVersion_t;
typedef struct CUuuid_st {
unsigned char bytes[16];
} CUuuid;
typedef int CUdevice;
typedef void* CUcontext;
typedef struct nvcuda_handle {
void *handle;
uint16_t verbose;
CUresult (*cuInit)(unsigned int Flags);
CUresult (*cuDriverGetVersion)(int *driverVersion);
CUresult (*cuDeviceGetCount)(int *);
CUresult (*cuDeviceGet)(CUdevice* device, int ordinal);
CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev);
CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2
// Context specific aspects
CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev);
CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total);
CUresult (*cuCtxDestroy)(CUcontext ctx);
} nvcuda_handle_t;
typedef struct nvcuda_init_resp {
char *err; // If err is non-null handle is invalid
nvcuda_handle_t ch;
int num_devices;
} nvcuda_init_resp_t;
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
void nvcuda_release(nvcuda_handle_t ch);
#endif // __GPU_INFO_NVCUDA_H__
#endif // __APPLE__

View file

@ -1,221 +0,0 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include "gpu_info_nvml.h"
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
nvmlReturn_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
{NULL, NULL},
};
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
if (!resp->ch.handle) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
nvml_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
resp->ch.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->ch.nvmlInit_v2)();
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf);
return;
}
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
} else {
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
}
}
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
resp->err = NULL;
nvmlDevice_t device;
nvmlMemory_t memInfo = {0};
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle isn't initialized");
return;
}
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
if (h.verbose) {
nvmlBrandType_t brand = 0;
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
}
}
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
}
}
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
nvmlDevice_t device;
int major = 0;
int minor = 0;
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle not initialized");
return;
}
unsigned int devices;
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
}
void nvml_release(nvml_handle_t h) {
LOG(h.verbose, "releasing nvml library\n");
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}
#endif // __APPLE__

View file

@ -1,57 +0,0 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_NVML_H__
#define __GPU_INFO_NVML_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum nvmlReturn_enum {
NVML_SUCCESS = 0,
// Other values omitted for now...
} nvmlReturn_t;
typedef void *nvmlDevice_t; // Opaque is sufficient
typedef struct nvmlMemory_st {
unsigned long long total;
unsigned long long free;
unsigned long long used;
} nvmlMemory_t;
typedef enum nvmlBrandType_enum
{
NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;
typedef struct nvml_handle {
void *handle;
uint16_t verbose;
nvmlReturn_t (*nvmlInit_v2)(void);
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
} nvml_handle_t;
typedef struct nvml_init_resp {
char *err; // If err is non-null handle is invalid
nvml_handle_t ch;
} nvml_init_resp_t;
typedef struct nvml_compute_capability {
char *err;
int major;
int minor;
} nvml_compute_capability_t;
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
void nvml_release(nvml_handle_t ch);
#endif // __GPU_INFO_NVML_H__
#endif // __APPLE__

View file

@ -9,23 +9,16 @@ import (
func TestBasicGetGPUInfo(t *testing.T) { func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo() info := GetGPUInfo()
assert.Contains(t, "cuda rocm cpu metal", info.Library) assert.Greater(t, len(info), 0)
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
switch runtime.GOOS { if info[0].Library != "cpu" {
case "darwin": assert.Greater(t, info[0].TotalMemory, uint64(0))
// TODO - remove this once MacOS returns some size for CPU assert.Greater(t, info[0].FreeMemory, uint64(0))
return
case "linux", "windows":
assert.Greater(t, info.TotalMemory, uint64(0))
assert.Greater(t, info.FreeMemory, uint64(0))
assert.Greater(t, info.DeviceCount, uint32(0))
default:
return
} }
} }
func TestCPUMemInfo(t *testing.T) { func TestCPUMemInfo(t *testing.T) {
info, err := getCPUMem() info, err := GetCPUMem()
assert.NoError(t, err) assert.NoError(t, err)
switch runtime.GOOS { switch runtime.GOOS {
case "darwin": case "darwin":

View file

@ -3,7 +3,6 @@ package gpu
type memInfo struct { type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"` TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"`
DeviceCount uint32 `json:"device_count,omitempty"`
} }
// Beginning of an `ollama info` command // Beginning of an `ollama info` command
@ -17,11 +16,49 @@ type GpuInfo struct {
// MinimumMemory represents the minimum memory required to use the GPU // MinimumMemory represents the minimum memory required to use the GPU
MinimumMemory uint64 `json:"-"` MinimumMemory uint64 `json:"-"`
// TODO add other useful attributes about the card here for discovery information // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
DependencyPath string `json:"lib_path,omitempty"`
// GPU information
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
Name string `json:"name"` // user friendly name if available
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
// TODO other performance capability info to help in scheduling decisions
} }
type Version struct { type GpuInfoList []GpuInfo
Major uint
Minor uint // Split up the set of gpu info's by Library and variant
Patch uint func (l GpuInfoList) ByLibrary() []GpuInfoList {
resp := []GpuInfoList{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
if info.Variant != "" {
requested += "_" + info.Variant
} }
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, info.Library)
resp = append(resp, []GpuInfo{info})
}
}
return resp
}
// Sort by Free Space
type ByFreeMemory []GpuInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

View file

@ -4,11 +4,14 @@ package integration
import ( import (
"context" "context"
"net/http" "log/slog"
"os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
) )
func TestOrcaMiniBlueSky(t *testing.T) { func TestOrcaMiniBlueSky(t *testing.T) {
@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
"seed": 123, "seed": 123,
}, },
} }
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"}) GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}
func TestUnicodeModelDir(t *testing.T) {
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
if runtime.GOOS != "windows" {
t.Skip("Unicode test only applicable to windows")
}
// Only works for local testing
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("TestUnicodeModelDir only works for local testing, skipping")
}
modelDir, err := os.MkdirTemp("", "ollama_埃")
require.NoError(t, err)
defer os.RemoveAll(modelDir)
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
oldModelsDir := os.Getenv("OLLAMA_MODELS")
if oldModelsDir == "" {
defer os.Unsetenv("OLLAMA_MODELS")
} else {
defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
}
err = os.Setenv("OLLAMA_MODELS", modelDir)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
} }

View file

@ -0,0 +1,225 @@
//go:build integration
package integration
import (
"context"
"log/slog"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "tinydolphin",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
[]string{"sunlight"},
[]string{"england", "english", "massachusetts", "pilgrims"},
}
)
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, req[i], resp[i])
}(i)
}
wg.Wait()
}
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req, resp := GenerateRequests()
// Get the server running (if applicable) warm the model up with a single initial request
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
var wg sync.WaitGroup
wg.Add(len(req))
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
for j := 0; j < 5; j++ {
slog.Info("Starting", "req", i, "iter", j)
// On slower GPUs it can take a while to process the 4 concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}
max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err)
const MB = uint64(1024 * 1024)
type model struct {
name string
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
}
smallModels := []model{
{
name: "orca-mini",
size: 2992 * MB,
},
{
name: "phi",
size: 2616 * MB,
},
{
name: "gemma:2b",
size: 2364 * MB,
},
{
name: "stable-code:3b",
size: 2608 * MB,
},
{
name: "starcoder2:3b",
size: 2166 * MB,
},
}
mediumModels := []model{
{
name: "llama2",
size: 5118 * MB,
},
{
name: "mistral",
size: 4620 * MB,
},
{
name: "orca-mini:7b",
size: 5118 * MB,
},
{
name: "dolphin-mistral",
size: 4620 * MB,
},
{
name: "gemma:7b",
size: 5000 * MB,
},
// TODO - uncomment this once #3565 is merged and this is rebased on it
// {
// name: "codellama:7b",
// size: 5118 * MB,
// },
}
// These seem to be too slow to be useful...
// largeModels := []model{
// {
// name: "llama2:13b",
// size: 7400 * MB,
// },
// {
// name: "codellama:13b",
// size: 7400 * MB,
// },
// {
// name: "orca-mini:13b",
// size: 7400 * MB,
// },
// {
// name: "gemma:7b",
// size: 5000 * MB,
// },
// {
// name: "starcoder2:15b",
// size: 9100 * MB,
// },
// }
var chosenModels []model
switch {
case max < 10000*MB:
slog.Info("selecting small models")
chosenModels = smallModels
// case max < 30000*MB:
default:
slog.Info("selecting medium models")
chosenModels = mediumModels
// default:
// slog.Info("selecting large models")
// chosenModels = largModels
}
req, resp := GenerateRequests()
for i := range req {
if i > len(chosenModels) {
break
}
req[i].Model = chosenModels[i].name
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure all the models are pulled before we get started
for _, r := range req {
require.NoError(t, PullIfMissing(ctx, client, r.Model))
}
var wg sync.WaitGroup
consumed := uint64(256 * MB) // Assume some baseline usage
for i := 0; i < len(req); i++ {
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
if i > 1 && consumed > max {
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
break
}
consumed += chosenModels[i].size
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 3; j++ {
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}

View file

@ -4,7 +4,6 @@ package integration
import ( import (
"context" "context"
"net/http"
"testing" "testing"
"time" "time"
@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128, "num_ctx": 128,
}, },
} }
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"}) GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
} }

View file

@ -5,7 +5,6 @@ package integration
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"net/http"
"testing" "testing"
"time" "time"
@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
}, },
} }
resp := "the ollamas" // Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel() defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp}) GenerateTestHelper(ctx, t, req, []string{resp})
} }
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

View file

@ -4,8 +4,6 @@ package integration
import ( import (
"context" "context"
"net/http"
"sync"
"testing" "testing"
"time" "time"
@ -45,25 +43,5 @@ var (
func TestIntegrationSimpleOrcaMini(t *testing.T) { func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel() defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0]) GenerateTestHelper(ctx, t, req[0], resp[0])
} }
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
}(i)
}
wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency

View file

@ -0,0 +1,117 @@
//go:build integration
package integration
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestMaxQueue(t *testing.T) {
// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
threadCount := 32
mq := os.Getenv("OLLAMA_MAX_QUEUE")
if mq != "" {
var err error
threadCount, err = strconv.Atoi(mq)
require.NoError(t, err)
} else {
os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
}
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}
resp := []string{"explore", "discover", "ocean"}
// CPU mode takes much longer at the limit with a large queue setting
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// Context for the worker threads so we can shut them down
// embedCtx, embedCancel := context.WithCancel(ctx)
embedCtx := ctx
var genwg sync.WaitGroup
go func() {
genwg.Add(1)
defer genwg.Done()
slog.Info("Starting generate request")
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
slog.Info("generate completed")
}()
// Give the generate a chance to get started before we start hammering on embed requests
time.Sleep(5 * time.Millisecond)
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
busyCount := 0
resetByPeerCount := 0
canceledCount := 0
succesCount := 0
counterMu := sync.Mutex{}
var embedwg sync.WaitGroup
for i := 0; i < threadCount; i++ {
go func(i int) {
embedwg.Add(1)
defer embedwg.Done()
slog.Info("embed started", "id", i)
embedReq := api.EmbeddingRequest{
Model: req.Model,
Prompt: req.Prompt,
Options: req.Options,
}
// Fresh client for every request
client, _ = GetTestEndpoint()
resp, genErr := client.Embeddings(embedCtx, &embedReq)
counterMu.Lock()
defer counterMu.Unlock()
switch {
case genErr == nil:
succesCount++
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
case errors.Is(genErr, context.Canceled):
canceledCount++
case strings.Contains(genErr.Error(), "busy"):
busyCount++
case strings.Contains(genErr.Error(), "connection reset by peer"):
resetByPeerCount++
default:
require.NoError(t, genErr, "%d request failed", i)
}
slog.Info("embed finished", "id", i)
}(i)
}
genwg.Wait()
slog.Info("generate done, waiting for embeds")
embedwg.Wait()
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
}

View file

@ -5,13 +5,14 @@ package integration
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json" "errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"math/rand" "math/rand"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -23,9 +24,13 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func Init() {
lifecycle.InitLogging()
}
func FindPort() string { func FindPort() string {
port := 0 port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@ -41,7 +46,7 @@ func FindPort() string {
return strconv.Itoa(port) return strconv.Itoa(port)
} }
func GetTestEndpoint() (string, string) { func GetTestEndpoint() (*api.Client, string) {
defaultPort := "11434" defaultPort := "11434"
ollamaHost := os.Getenv("OLLAMA_HOST") ollamaHost := os.Getenv("OLLAMA_HOST")
@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
port = FindPort() port = FindPort()
} }
url := fmt.Sprintf("%s:%s", host, port) slog.Info("server connection", "host", host, "port", port)
slog.Info("server connection", "url", url)
return scheme, url return api.NewClient(
&url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
} }
// TODO make fanicier, grab logs, etc.
var serverMutex sync.Mutex var serverMutex sync.Mutex
var serverReady bool var serverReady bool
func StartServer(ctx context.Context, ollamaHost string) error { func startServer(ctx context.Context, ollamaHost string) error {
// Make sure the server has been built // Make sure the server has been built
CLIName, err := filepath.Abs("../ollama") CLIName, err := filepath.Abs("../ollama")
if err != nil { if err != nil {
@ -98,7 +107,7 @@ func StartServer(ctx context.Context, ollamaHost string) error {
if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost { if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
slog.Info("setting env", "OLLAMA_HOST", ollamaHost) slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
os.Setenv("OLLAMA_HOST", ollamaHost) t.Setenv("OLLAMA_HOST", ollamaHost)
} }
slog.Info("starting server", "url", ollamaHost) slog.Info("starting server", "url", ollamaHost)
@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
return nil return nil
} }
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error { func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName) slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName} showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
return err
}
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON)) showCtx, cancel := context.WithDeadlineCause(
if err != nil { ctx,
time.Now().Add(5*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName),
)
defer cancel()
_, err := client.Show(showCtx, showReq)
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
break
case err != nil:
return err return err
} default:
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode == 200 {
slog.Info("model already present", "model", modelName) slog.Info("model already present", "model", modelName)
return nil return nil
} }
slog.Info("model missing", "status", response.StatusCode) slog.Info("model missing", "model", modelName)
pullReq := &api.PullRequest{Name: modelName, Stream: &stream} stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
requestJSON, err = json.Marshal(pullReq) stallTimer := time.NewTimer(stallDuration)
if err != nil { fn := func(resp api.ProgressResponse) error {
return err // fmt.Print(".")
if !stallTimer.Reset(stallDuration) {
return fmt.Errorf("stall was detected, aborting status reporting")
} }
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
if err != nil {
return err
}
slog.Info("pulling", "model", modelName)
response, err = client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode != 200 {
return fmt.Errorf("failed to pull model") // TODO more details perhaps
}
slog.Info("model pulled", "model", modelName)
return nil return nil
} }
stream := true
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
var pullError error
done := make(chan int)
go func() {
pullError = client.Pull(ctx, pullReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
return fmt.Errorf("download stalled")
case <-done:
return pullError
}
}
var serverProcMutex sync.Mutex var serverProcMutex sync.Mutex
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) { // Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
// Starts the server if needed
// TODO maybe stuff in an init routine? func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
lifecycle.InitLogging() client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
requestJSON, err := json.Marshal(genReq) serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil { if err != nil {
t.Fatalf("Error serializing request: %v", err) t.Fatalf("failed to generate log file: %s", err)
} }
defer func() { lifecycle.ServerLogFile = fp.Name()
fp.Close()
require.NoError(t, startServer(ctx, testEndpoint))
}
return client, testEndpoint, func() {
if os.Getenv("OLLAMA_TEST_EXISTING") == "" { if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock() defer serverProcMutex.Unlock()
if t.Failed() { if t.Failed() {
@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
os.Stderr.Write(data) os.Stderr.Write(data)
slog.Warn("END OF SERVER") slog.Warn("END OF SERVER")
} }
err = os.Remove(lifecycle.ServerLogFile) err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err) slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
} }
} }
}
}
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
}
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.GenerateResponse) error {
// fmt.Print(".")
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
genReq.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(ctx, &genReq, fn)
done <- 0
}() }()
scheme, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" { select {
serverProcMutex.Lock() case <-stallTimer.C:
fp, err := os.CreateTemp("", "ollama-server-*.log") if buf.Len() == 0 {
if err != nil { t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
t.Fatalf("failed to generate log file: %s", err) } else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
} }
lifecycle.ServerLogFile = fp.Name() case <-done:
fp.Close() require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
assert.NoError(t, StartServer(ctx, testEndpoint))
}
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
if err != nil {
t.Fatalf("Error pulling model: %v", err)
}
// Make the request and get the response
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
// Set the content type for the request
req.Header.Set("Content-Type", "application/json")
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
t.Fatalf("Error making request: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body))
// Verify the response is valid JSON
var payload api.GenerateResponse
err = json.Unmarshal(body, &payload)
if err != nil {
assert.NoError(t, err, body)
}
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String()
atLeastOne := false atLeastOne := false
for _, resp := range anyResp { for _, resp := range anyResp {
if strings.Contains(strings.ToLower(payload.Response), resp) { if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true atLeastOne = true
break break
} }
} }
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response) require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
}
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "why is the color of dirt brown?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of independence day?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the composition of air?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
},
[][]string{
[]string{"sunlight"},
[]string{"soil", "organic", "earth", "black", "tan"},
[]string{"england", "english", "massachusetts", "pilgrims"},
[]string{"fourth", "july", "declaration", "independence"},
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
} }

View file

@ -1032,7 +1032,7 @@ struct llama_server_context
slot.has_next_token = false; slot.has_next_token = false;
} }
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
{ {
slot.stopped_eos = true; slot.stopped_eos = true;
slot.has_next_token = false; slot.has_next_token = false;
@ -1144,12 +1144,15 @@ struct llama_server_context
res.result_json = json res.result_json = json
{ {
{"content", tkn.text_to_send},
{"stop", false}, {"stop", false},
{"slot_id", slot.id}, {"slot_id", slot.id},
{"multimodal", multimodal} {"multimodal", multimodal}
}; };
if (!llama_token_is_eog(model, tkn.tok)) {
res.result_json["content"] = tkn.text_to_send;
}
if (slot.sparams.n_probs > 0) if (slot.sparams.n_probs > 0)
{ {
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
@ -1183,8 +1186,6 @@ struct llama_server_context
{"model", params.model_alias}, {"model", params.model_alias},
{"tokens_predicted", slot.n_decoded}, {"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens}, {"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt},
{"truncated", slot.truncated}, {"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos}, {"stopped_eos", slot.stopped_eos},
{"stopped_word", slot.stopped_word}, {"stopped_word", slot.stopped_word},
@ -2644,18 +2645,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
if (strncmp(sep, "int:", 4) == 0) { if (strncmp(sep, "int:", 4) == 0) {
sep += 4; sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep); kvo.val_i64 = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) { } else if (strncmp(sep, "float:", 6) == 0) {
sep += 6; sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep); kvo.val_f64 = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) { } else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5; sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) { if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true; kvo.val_bool = true;
} else if (std::strcmp(sep, "false") == 0) { } else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false; kvo.val_bool = false;
} else { } else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
invalid_param = true; invalid_param = true;

140
llm/filetype.go Normal file
View file

@ -0,0 +1,140 @@
package llm
import "fmt"
type fileType uint32
const (
fileTypeF32 fileType = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ4_2 // unused
fileTypeQ4_3 // unused
fileTypeQ8_0
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
fileTypeIQ2_XXS
fileTypeIQ2_XS
fileTypeQ2_K_S
fileTypeQ3_K_XS
fileTypeIQ3_XXS
fileTypeUnknown
)
func ParseFileType(s string) (fileType, error) {
switch s {
case "F32":
return fileTypeF32, nil
case "F16":
return fileTypeF16, nil
case "Q4_0":
return fileTypeQ4_0, nil
case "Q4_1":
return fileTypeQ4_1, nil
case "Q4_1_F16":
return fileTypeQ4_1_F16, nil
case "Q8_0":
return fileTypeQ8_0, nil
case "Q5_0":
return fileTypeQ5_0, nil
case "Q5_1":
return fileTypeQ5_1, nil
case "Q2_K":
return fileTypeQ2_K, nil
case "Q3_K_S":
return fileTypeQ3_K_S, nil
case "Q3_K_M":
return fileTypeQ3_K_M, nil
case "Q3_K_L":
return fileTypeQ3_K_L, nil
case "Q4_K_S":
return fileTypeQ4_K_S, nil
case "Q4_K_M":
return fileTypeQ4_K_M, nil
case "Q5_K_S":
return fileTypeQ5_K_S, nil
case "Q5_K_M":
return fileTypeQ5_K_M, nil
case "Q6_K":
return fileTypeQ6_K, nil
case "IQ2_XXS":
return fileTypeIQ2_XXS, nil
case "IQ2_XS":
return fileTypeIQ2_XS, nil
case "Q2_K_S":
return fileTypeQ2_K_S, nil
case "Q3_K_XS":
return fileTypeQ3_K_XS, nil
case "IQ3_XXS":
return fileTypeIQ3_XXS, nil
default:
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
}
}
func (t fileType) String() string {
switch t {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case fileTypeQ3_K_XS:
return "Q3_K_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
default:
return "unknown"
}
}
func (t fileType) Value() uint32 {
return uint32(t)
}

View file

@ -57,11 +57,10 @@ init_vars
git_module_setup git_module_setup
apply_patches apply_patches
init_vars init_vars
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
# Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static"
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then # Enables optimized Dockerfile builds using a blanket skip and targeted overrides
# Static build for linking into the Go binary # Static build for linking into the Go binary
init_vars init_vars
CMAKE_TARGETS="--target llama --target ggml" CMAKE_TARGETS="--target llama --target ggml"
@ -71,7 +70,8 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
build build
fi fi
init_vars
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# Users building from source can tune the exact flags we pass to cmake for configuring # Users building from source can tune the exact flags we pass to cmake for configuring
# llama.cpp, and we'll build only 1 CPU variant in that case as the default. # llama.cpp, and we'll build only 1 CPU variant in that case as the default.
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
@ -172,7 +172,15 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
# Disabling has minimal performance effect while maintaining compatibility. # Disabling has minimal performance effect while maintaining compatibility.
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off" ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
fi fi
CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}" # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU"
else
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}" BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda" EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
build build
@ -217,6 +225,12 @@ if [ -d "${ROCM_PATH}" ]; then
fi fi
init_vars init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)" CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
CMAKE_DEFS="${CMAKE_DEFS} ${OLLAMA_CUSTOM_ROCM_DEFS}"
echo "Building custom ROCM GPU"
fi
BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}" BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu" EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
build build

View file

@ -26,15 +26,25 @@ function amdGPUs {
$GPU_LIST -join ';' $GPU_LIST -join ';'
} }
function init_vars { function init_vars {
if (!$script:SRC_DIR) {
$script:SRC_DIR = $(resolve-path "..\..\") $script:SRC_DIR = $(resolve-path "..\..\")
}
if (!$script:llamacppDir) {
$script:llamacppDir = "../llama.cpp" $script:llamacppDir = "../llama.cpp"
}
if (!$script:cmakeTargets) {
$script:cmakeTargets = @("ollama_llama_server")
}
$script:cmakeDefs = @( $script:cmakeDefs = @(
"-DBUILD_SHARED_LIBS=on", "-DBUILD_SHARED_LIBS=on",
"-DLLAMA_NATIVE=off" "-DLLAMA_NATIVE=off"
) )
$script:cmakeTargets = @("ollama_llama_server") $script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
$script:ARCH = "amd64" # arm not yet supported. $script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
md "$script:DIST_BASE" -ea 0 > $null
if ($env:CGO_CFLAGS -contains "-g") { if ($env:CGO_CFLAGS -contains "-g") {
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo") $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
$script:config = "RelWithDebInfo" $script:config = "RelWithDebInfo"
@ -55,7 +65,6 @@ function init_vars {
} else { } else {
$script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR $script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
} }
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) { if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80" $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
@ -134,21 +143,18 @@ function sign {
} }
} }
function compress { function install {
if ($script:GZIP -eq $null) { write-host "Installing binaries to dist dir ${script:distDir}"
write-host "gzip not installed, not compressing files" mkdir ${script:distDir} -ErrorAction SilentlyContinue
return
}
write-host "Compressing binaries..."
$binaries = dir "${script:buildDir}/bin/*.exe" $binaries = dir "${script:buildDir}/bin/*.exe"
foreach ($file in $binaries) { foreach ($file in $binaries) {
& "$script:GZIP" --best -f $file copy-item -Path $file -Destination ${script:distDir} -Force
} }
write-host "Compressing dlls..." write-host "Installing dlls to dist dir ${script:distDir}"
$dlls = dir "${script:buildDir}/bin/*.dll" $dlls = dir "${script:buildDir}/bin/*.dll"
foreach ($file in $dlls) { foreach ($file in $dlls) {
& "$script:GZIP" --best -f $file copy-item -Path $file -Destination ${script:distDir} -Force
} }
} }
@ -169,18 +175,14 @@ function cleanup {
} }
} }
init_vars
git_module_setup
apply_patches
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer # -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
function build_static() {
if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
# GCC build for direct linking into the Go binary # GCC build for direct linking into the Go binary
init_vars init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast # cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
@ -189,6 +191,7 @@ write-host "Checking for MinGW..."
# error action ensures we exit on failure # error action ensures we exit on failure
get-command gcc get-command gcc
get-command mingw32-make get-command mingw32-make
$oldTargets = $script:cmakeTargets
$script:cmakeTargets = @("llama", "ggml") $script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @( $script:cmakeDefs = @(
"-G", "MinGW Makefiles" "-G", "MinGW Makefiles"
@ -204,36 +207,60 @@ $script:cmakeDefs = @(
$script:buildDir="../build/windows/${script:ARCH}_static" $script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library" write-host "Building static library"
build build
$script:cmakeTargets = $oldTargets
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
write-host "Building LCD CPU"
build
sign
compress
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
write-host "Building AVX CPU"
build
sign
compress
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
write-host "Building AVX2 CPU"
build
sign
compress
} else { } else {
write-host "Skipping CPU generation step as requested" write-host "Skipping CPU generation step as requested"
} }
}
if ($null -ne $script:CUDA_LIB_DIR) { function build_cpu($gen_arch) {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
$script:distDir="$script:DIST_BASE\cpu"
write-host "Building LCD CPU"
build
sign
install
} else {
write-host "Skipping CPU generation step as requested"
}
}
function build_cpu_avx() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
$script:distDir="$script:DIST_BASE\cpu_avx"
write-host "Building AVX CPU"
build
sign
install
} else {
write-host "Skipping CPU AVX generation step as requested"
}
}
function build_cpu_avx2() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
$script:distDir="$script:DIST_BASE\cpu_avx2"
write-host "Building AVX2 CPU"
build
sign
install
} else {
write-host "Skipping CPU AVX2 generation step as requested"
}
}
function build_cuda() {
if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
# Then build cuda as a dynamically loaded library # Then build cuda as a dynamically loaded library
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe" $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
@ -242,13 +269,28 @@ if ($null -ne $script:CUDA_LIB_DIR) {
} }
init_vars init_vars
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT" $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}") $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) {
write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`""
$script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}")
write-host "building custom CUDA GPU"
}
build build
sign sign
compress install
write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
} else {
write-host "Skipping CUDA generation step"
}
} }
if ($null -ne $env:HIP_PATH) { function build_rocm() {
if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
if ($null -ne $script:ROCM_VERSION) { if ($null -ne $script:ROCM_VERSION) {
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
@ -256,6 +298,7 @@ if ($null -ne $env:HIP_PATH) {
init_vars init_vars
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT" $script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
$script:distDir="$script:DIST_BASE\rocm$script:ROCM_VARIANT"
$script:cmakeDefs += @( $script:cmakeDefs += @(
"-G", "Ninja", "-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe", "-DCMAKE_C_COMPILER=clang.exe",
@ -274,7 +317,11 @@ if ($null -ne $env:HIP_PATH) {
# We have to clobber the LIB var from the developer shell for clang to work properly # We have to clobber the LIB var from the developer shell for clang to work properly
$env:LIB="" $env:LIB=""
if ($null -ne $env:OLLAMA_CUSTOM_ROCM_DEFS) {
write-host "OLLAMA_CUSTOM_ROCM_DEFS=`"${env:OLLAMA_CUSTOM_ROCM_DEFS}`""
$script:cmakeDefs += @("${env:OLLAMA_CUSTOM_ROCM_DEFS}")
write-host "building custom ROCM GPU"
}
write-host "Building ROCm" write-host "Building ROCm"
build build
# Ninja doesn't prefix with config name # Ninja doesn't prefix with config name
@ -283,9 +330,40 @@ if ($null -ne $env:HIP_PATH) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll" & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
} }
sign sign
compress install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
# amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
} else {
write-host "Skipping ROCm generation step"
}
} }
init_vars
if ($($args.count) -eq 0) {
git_module_setup
apply_patches
build_static
if ($script:ARCH -eq "arm64") {
build_cpu("ARM64")
} else { # amd64
build_cpu("x64")
build_cpu_avx
build_cpu_avx2
build_cuda
build_rocm
}
cleanup cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})" write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"
} else {
for ( $i = 0; $i -lt $args.count; $i++ ) {
write-host "performing $($args[$i])"
& $($args[$i])
}
}

View file

@ -13,82 +13,6 @@ type GGML struct {
model model
} }
const (
fileTypeF32 uint32 = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ8_0 uint32 = iota + 2
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
fileTypeIQ2_XXS
fileTypeIQ2_XS
fileTypeQ2_K_S
fileTypeQ3_K_XS
fileTypeIQ3_XXS
)
func fileType(fileType uint32) string {
switch fileType {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case fileTypeQ3_K_XS:
return "Q3_K_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
default:
return "unknown"
}
}
type model interface { type model interface {
KV() KV KV() KV
Tensors() Tensors Tensors() Tensors
@ -121,12 +45,12 @@ func (kv KV) ParameterCount() uint64 {
return kv.u64("general.parameter_count") return kv.u64("general.parameter_count")
} }
func (kv KV) FileType() string { func (kv KV) FileType() fileType {
if u64 := kv.u64("general.file_type"); u64 > 0 { if u64 := kv.u64("general.file_type"); u64 > 0 {
return fileType(uint32(u64)) return fileType(uint32(u64))
} }
return "unknown" return fileTypeUnknown
} }
func (kv KV) BlockCount() uint64 { func (kv KV) BlockCount() uint64 {
@ -286,6 +210,23 @@ const (
var ErrUnsupportedFormat = errors.New("unsupported model format") var ErrUnsupportedFormat = errors.New("unsupported model format")
func DetectGGMLType(b []byte) string {
switch binary.LittleEndian.Uint32(b[:4]) {
case FILE_MAGIC_GGML:
return "ggml"
case FILE_MAGIC_GGMF:
return "ggmf"
case FILE_MAGIC_GGJT:
return "ggjt"
case FILE_MAGIC_GGLA:
return "ggla"
case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
return "gguf"
default:
return ""
}
}
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
var magic uint32 var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
@ -343,7 +284,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(embedding+vocab)+embedding*vocab*105/128, 4*batch*(embedding+vocab)+embedding*vocab*105/128,
) )
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok { if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max(
3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
)
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
// mixtral 8x7b
ffnGateWeight1 := ffnGateWeight.Shape[1] ffnGateWeight1 := ffnGateWeight.Shape[1]
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1) fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
partialOffload = max( partialOffload = max(

View file

@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
llm.kv[k] = v llm.kv[k] = v
} }
slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
// decode tensors // decode tensors
for i := 0; uint64(i) < llm.numTensor(); i++ { for i := 0; uint64(i) < llm.numTensor(); i++ {
name, err := readGGUFString(llm, rs) name, err := readGGUFString(llm, rs)
@ -465,11 +463,13 @@ var ggufKVOrder = map[string][]string{
"llama.embedding_length", "llama.embedding_length",
"llama.block_count", "llama.block_count",
"llama.feed_forward_length", "llama.feed_forward_length",
"llama.rope.dimension_count",
"llama.attention.head_count", "llama.attention.head_count",
"llama.attention.head_count_kv", "llama.attention.head_count_kv",
"llama.attention.layer_norm_rms_epsilon", "llama.attention.layer_norm_rms_epsilon",
"llama.rope.freq_base", "llama.rope.freq_base",
"llama.rope.dimension_count",
"llama.expert_count",
"llama.expert_used_count",
"gemma.context_length", "gemma.context_length",
"gemma.embedding_length", "gemma.embedding_length",
"gemma.block_count", "gemma.block_count",
@ -577,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err return err
} }
} }
default:
return fmt.Errorf("improper type for '%s'", k)
} }
if err != nil { if err != nil {
return err return err
@ -598,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err return err
} }
dims := 1 dims := 0
if tensor.Shape[1] > 0 { for cnt := 0; cnt < len(tensor.Shape); cnt++ {
dims = 2 if tensor.Shape[cnt] > 0 {
dims++
}
} }
if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil { if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {

@ -1 +1 @@
Subproject commit 7593639ce335e8d7f89aa9a54d616951f273af60 Subproject commit 952d03dbead16e4dbdd1d3458486340673cc2465

View file

@ -4,6 +4,7 @@ package llm
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++ // #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++ // #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++ // #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++ // #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++ // #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h> // #include <stdlib.h>
@ -19,7 +20,7 @@ func SystemInfo() string {
return C.GoString(C.llama_print_system_info()) return C.GoString(C.llama_print_system_info())
} }
func Quantize(infile, outfile, filetype string) error { func Quantize(infile, outfile string, ftype fileType) error {
cinfile := C.CString(infile) cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile)) defer C.free(unsafe.Pointer(cinfile))
@ -28,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
params := C.llama_model_quantize_default_params() params := C.llama_model_quantize_default_params()
params.nthread = -1 params.nthread = -1
params.ftype = ftype.Value()
switch filetype { if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
case "F32": return fmt.Errorf("llama_model_quantize: %d", rc)
params.ftype = fileTypeF32
case "F16":
params.ftype = fileTypeF16
case "Q4_0":
params.ftype = fileTypeQ4_0
case "Q4_1":
params.ftype = fileTypeQ4_1
case "Q4_1_F16":
params.ftype = fileTypeQ4_1_F16
case "Q8_0":
params.ftype = fileTypeQ8_0
case "Q5_0":
params.ftype = fileTypeQ5_0
case "Q5_1":
params.ftype = fileTypeQ5_1
case "Q2_K":
params.ftype = fileTypeQ2_K
case "Q3_K_S":
params.ftype = fileTypeQ3_K_S
case "Q3_K_M":
params.ftype = fileTypeQ3_K_M
case "Q3_K_L":
params.ftype = fileTypeQ3_K_L
case "Q4_K_S":
params.ftype = fileTypeQ4_K_S
case "Q4_K_M":
params.ftype = fileTypeQ4_K_M
case "Q5_K_S":
params.ftype = fileTypeQ5_K_S
case "Q5_K_M":
params.ftype = fileTypeQ5_K_M
case "Q6_K":
params.ftype = fileTypeQ6_K
case "IQ2_XXS":
params.ftype = fileTypeIQ2_XXS
case "IQ2_XS":
params.ftype = fileTypeIQ2_XS
case "Q2_K_S":
params.ftype = fileTypeQ2_K_S
case "Q3_K_XS":
params.ftype = fileTypeQ3_K_XS
case "IQ3_XXS":
params.ftype = fileTypeIQ3_XXS
default:
return fmt.Errorf("unknown filetype: %s", filetype)
}
if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
return fmt.Errorf("llama_model_quantize: %d", retval)
} }
return nil return nil

View file

@ -2,5 +2,5 @@ package llm
import "embed" import "embed"
//go:embed build/windows/*/*/bin/* // unused on windows
var libEmbed embed.FS var libEmbed embed.FS

185
llm/memory.go Normal file
View file

@ -0,0 +1,185 @@
package llm
import (
"fmt"
"log/slog"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig"
)
// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
var estimatedVRAM uint64
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
// Split up the GPUs by type and try them
for _, gpus := range allGpus.ByLibrary() {
var layerCount int
layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
return true, estimatedVRAM
}
} else {
if layerCount > 0 && layerCount >= opts.NumGPU {
return true, estimatedVRAM
}
}
}
return false, estimatedVRAM
}
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
var memoryAvailable uint64
for _, info := range gpus {
memoryAvailable += info.FreeMemory
}
if envconfig.MaxVRAM > 0 {
memoryAvailable = envconfig.MaxVRAM
}
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
memoryMinimum := gpus[0].MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(len(gpus))
graphPartialOffload *= uint64(len(gpus))
// on metal there's no partial offload overhead
if gpus[0].Library == "metal" {
graphPartialOffload = graphFullOffload
}
layers := ggml.Tensors().Layers()
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size()
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size()
var memoryLayerOutput uint64
if layer, ok := layers["output_norm"]; ok {
memoryLayerOutput += layer.size()
}
if layer, ok := layers["output"]; ok {
memoryLayerOutput += layer.size()
} else if layer, ok := layers["token_embd"]; ok {
memoryLayerOutput += layer.size()
}
if gpus[0].Library == "metal" && opts.UseMMap {
// memory is preallocated for output tensors
memoryRequiredTotal += memoryLayerOutput
memoryRequiredPartial += memoryLayerOutput
}
var layerCount int
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
if gpus[0].Library != "metal" || !opts.UseMMap {
// memory was not preallocated for output tensors
memoryRequiredTotal += memoryLayerOutput
}
if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
if gpus[0].Library == "cpu" {
return 0, 0, memoryRequiredTotal
}
if memoryRequiredPartial > memoryAvailable {
slog.Debug("insufficient VRAM to load any model layers")
return 0, 0, memoryRequiredTotal
}
return layerCount, memoryRequiredPartial, memoryRequiredTotal
}

View file

@ -0,0 +1,12 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index e431c7f7..f077e688 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -3,6 +3,7 @@
// I'll gradually clean and extend it
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
#include "clip.h"
+#include "common.h"
#include "log.h"
#include "ggml.h"
#include "ggml-alloc.h"

45
llm/patches/04-metal.diff Normal file
View file

@ -0,0 +1,45 @@
diff --git a/ggml-metal.m b/ggml-metal.m
index 0207b787..b5e9884b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
// to the matrix-vector kernel
int ne11_mm_min = 1;
-#if 0
// the numbers below are measured on M2 Ultra for 7B and 13B models
// these numbers do not translate to other devices or model sizes
// TODO: need to find a better approach
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
- switch (src0t) {
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
- case GGML_TYPE_Q5_0: // not tested yet
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
- default: ne11_mm_min = 1; break;
- }
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
}
-#endif
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel

View file

@ -0,0 +1,24 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index e3c9bcd4..b43f892d 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -573,14 +573,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
struct ggml_tensor * embeddings = inp;
if (ctx->has_class_embedding) {
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+ }
+ ggml_set_name(embeddings, "embeddings");
+ ggml_set_input(embeddings);
+
+ if (ctx->has_class_embedding) {
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
embeddings = ggml_acc(ctx0, embeddings, inp,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
}
- ggml_set_name(embeddings, "embeddings");
- ggml_set_input(embeddings);
-
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
ggml_set_name(positions, "positions");

View file

@ -9,6 +9,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
@ -17,7 +18,7 @@ import (
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
) )
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama") var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
func Init() error { func Init() error {
payloadsDir, err := gpu.PayloadsDir() payloadsDir, err := gpu.PayloadsDir()
@ -25,6 +26,7 @@ func Init() error {
return err return err
} }
if runtime.GOOS != "windows" {
slog.Info("extracting embedded files", "dir", payloadsDir) slog.Info("extracting embedded files", "dir", payloadsDir)
binGlob := "build/*/*/*/bin/*" binGlob := "build/*/*/*/bin/*"
@ -33,6 +35,7 @@ func Init() error {
if err != nil { if err != nil {
return fmt.Errorf("extract binaries: %v", err) return fmt.Errorf("extract binaries: %v", err)
} }
}
var variants []string var variants []string
for v := range availableServers() { for v := range availableServers() {
@ -138,6 +141,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
return servers return servers
} }
// Return the optimal server for this CPU architecture
func serverForCpu() string {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
return "metal"
}
variant := gpu.GetCPUVariant()
availableServers := availableServers()
if variant != "" {
for cmp := range availableServers {
if cmp == "cpu_"+variant {
return cmp
}
}
}
return "cpu"
}
// extract extracts the embedded files to the target directory // extract extracts the embedded files to the target directory
func extractFiles(targetDir string, glob string) error { func extractFiles(targetDir string, glob string) error {
files, err := fs.Glob(libEmbed, glob) files, err := fs.Glob(libEmbed, glob)

View file

@ -21,21 +21,47 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig"
) )
// LlamaServer is an instance of the llama.cpp server type LlamaServer interface {
type LlamaServer struct { Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
EstimatedVRAM() uint64
}
// llmServer is an instance of the llama.cpp server
type llmServer struct {
port int port int
cmd *exec.Cmd cmd *exec.Cmd
done chan error // Channel to signal when the process exits done chan error // Channel to signal when the process exits
status *StatusWriter status *StatusWriter
options api.Options options api.Options
// TODO - this should be broken down by GPU
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
estimatedTotal uint64 // Total size of model
totalLayers uint64
gpuCount int
sem *semaphore.Weighted
}
func LoadModel(model string) (*GGML, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
} }
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
f, err := os.Open(model) f, err := os.Open(model)
if err != nil { if err != nil {
return nil, err return nil, err
@ -43,144 +69,69 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
defer f.Close() defer f.Close()
ggml, _, err := DecodeGGML(f) ggml, _, err := DecodeGGML(f)
if err != nil { return ggml, err
return nil, err
} }
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
var err error
if opts.NumCtx > int(ggml.KV().ContextLength()) { if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength()) slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
} }
if opts.NumCtx < 4 { if opts.NumCtx < 4 {
opts.NumCtx = 4 opts.NumCtx = 4
} }
memoryAvailable, _ := gpu.CheckVRAM() cpuRunner := ""
info := gpu.GetGPUInfo() var estimatedVRAM uint64
var estimatedTotal uint64
var systemMemory uint64
gpuCount := len(gpus)
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
memoryMinimum := info.MinimumMemory // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context cpuRunner = serverForCpu()
opts.NumCtx = max(opts.NumCtx, 2048) gpuCount = 0
} } else {
if gpus[0].Library == "metal" {
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv memInfo, err := gpu.GetCPUMem()
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV() if err != nil {
slog.Error("failed to lookup system memory", "error", err)
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch))) } else {
if graphPartialOffload == 0 { systemMemory = memInfo.TotalMemory
graphPartialOffload = ggml.KV().GQA() * kv / 6 slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(info.DeviceCount)
graphPartialOffload *= uint64(info.DeviceCount)
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if info.Library != "metal" {
if memoryRequiredPartial > memoryAvailable {
info.Library = "cpu"
} }
} }
var layers int
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
var layerCount int if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
layers := ggml.Tensors().Layers() // disable partial offloading when model is greater than total system memory as this
for i := 0; i < int(ggml.KV().BlockCount()); i++ { // can lead to locking up the system
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
var memoryLayerOutput uint64
for k, v := range layers {
if !strings.HasPrefix(k, "blk.") {
memoryLayerOutput += v.size()
}
}
memoryRequiredTotal += memoryLayerOutput
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
// disable partial offloading when model is greater than total system memory
opts.NumGPU = 0 opts.NumGPU = 0
} else if memoryAvailable > memoryRequiredTotal { } else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
layerCount = int(ggml.KV().BlockCount()) + 1 opts.NumGPU = layers
memoryRequiredPartial = memoryRequiredTotal }
} }
if opts.NumGPU < 0 { // Loop through potential servers
opts.NumGPU = layerCount finalErr := fmt.Errorf("no suitable llama servers found")
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
if len(adapters) > 1 { if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided") return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
} }
availableServers := availableServers() availableServers := availableServers()
servers := serversForGpu(info) var servers []string
if cpuRunner != "" {
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY") servers = []string{cpuRunner}
} else {
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
}
demandLib := envconfig.LLMLibrary
if demandLib != "" { if demandLib != "" {
serverPath := availableServers[demandLib] serverPath := availableServers[demandLib]
if serverPath == "" { if serverPath == "" {
@ -188,11 +139,15 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
} else { } else {
slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath) slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
servers = []string{demandLib} servers = []string{demandLib}
if strings.HasPrefix(demandLib, "cpu") {
// Omit the GPU flag to silence the warning
opts.NumGPU = -1
}
} }
} }
if len(servers) == 0 { if len(servers) == 0 {
return nil, fmt.Errorf("no servers found for %v", info) return nil, fmt.Errorf("no servers found for %v", gpus)
} }
params := []string{ params := []string{
@ -201,7 +156,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
"--batch-size", fmt.Sprintf("%d", opts.NumBatch), "--batch-size", fmt.Sprintf("%d", opts.NumBatch),
"--embedding", "--embedding",
} }
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { if envconfig.Debug {
params = append(params, "--log-format", "json") params = append(params, "--log-format", "json")
} else { } else {
params = append(params, "--log-disable") params = append(params, "--log-disable")
@ -211,7 +166,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU)) params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
} }
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { if envconfig.Debug {
params = append(params, "--verbose") params = append(params, "--verbose")
} }
@ -249,10 +204,30 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--numa") params = append(params, "--numa")
} }
// Loop through potential servers numParallel := envconfig.NumParallel
var finalErr error
// TODO (jmorganca): multimodal models don't support parallel yet
// see https://github.com/ollama/ollama/issues/4165
if len(projectors) > 0 {
numParallel = 1
slog.Warn("multimodal models don't support parallel requests yet")
}
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := 0; i < len(servers); i++ { for i := 0; i < len(servers); i++ {
dir := availableServers[servers[i]] dir := availableServers[servers[i]]
if dir == "" {
// Shouldn't happen
finalErr = fmt.Errorf("[%d] server %s not listed in available servers %v", i, servers[i], availableServers)
slog.Error("sever list inconsistent", "error", finalErr)
continue
}
if strings.HasPrefix(servers[i], "cpu") {
// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
gpuCount = 0
}
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race // Find an availableServers port, retry on each iterration in case the failure was a port conflict race
port := 0 port := 0
@ -273,12 +248,21 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
pathEnv = "PATH" pathEnv = "PATH"
} }
// append the server directory to LD_LIBRARY_PATH/PATH // prepend the server directory to LD_LIBRARY_PATH/PATH
libraryPaths := []string{dir} libraryPaths := []string{dir}
if libraryPath, ok := os.LookupEnv(pathEnv); ok { if libraryPath, ok := os.LookupEnv(pathEnv); ok {
// Append our runner directory to the path // Append our runner directory to the path
// This will favor system libraries over our bundled library dependencies // This will favor system libraries over our bundled library dependencies
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...) libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
}
// Note: we always put the dependency path first
// since this was the exact version we verified for AMD GPUs
// and we favor what the user had in their path
if gpus[0].DependencyPath != "" {
// TODO refine for multi-gpu support
libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
} }
server := filepath.Join(dir, "ollama_llama_server") server := filepath.Join(dir, "ollama_llama_server")
@ -286,21 +270,66 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
server = server + ".exe" server = server + ".exe"
} }
s := &LlamaServer{ // Detect tmp cleaners wiping out the file
_, err := os.Stat(server)
if errors.Is(err, os.ErrNotExist) {
slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err)
err = Init()
if err != nil {
slog.Warn("failed to reinitialize payloads", "error", err)
return nil, err
}
}
s := &llmServer{
port: port, port: port,
cmd: exec.Command(server, finalParams...), cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr), status: NewStatusWriter(os.Stderr),
options: opts, options: opts,
estimatedVRAM: estimatedVRAM,
estimatedTotal: estimatedTotal,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: ggml.KV().BlockCount() + 1,
gpuCount: gpuCount,
} }
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
slog.Debug(libEnv) s.cmd.Env = os.Environ()
s.cmd.Env = append(os.Environ(), libEnv)
s.cmd.Stdout = os.Stdout s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status s.cmd.Stderr = s.status
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path and visible devices variable with our adjusted version
pathNeeded := true
devicesNeeded := visibleDevicesEnv != ""
for i := range s.cmd.Env {
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) {
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
devicesNeeded = false
}
}
if pathNeeded {
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
}
if devicesNeeded {
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
}
slog.Info("starting llama server", "cmd", s.cmd.String()) slog.Info("starting llama server", "cmd", s.cmd.String())
// Log at debug as the environment is inherited and might contain sensitive information
slog.Debug("subprocess", "environment", s.cmd.Env)
if err = s.cmd.Start(); err != nil { if err = s.cmd.Start(); err != nil {
// Detect permission denied and augment them essage about noexec
if errors.Is(err, os.ErrPermission) {
finalErr = fmt.Errorf("unable to start server %w. %s may have noexec set. Set OLLAMA_TMPDIR for server to a writable executable directory", err, dir)
continue
}
msg := "" msg := ""
if s.status != nil && s.status.LastErrMsg != "" { if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg msg = s.status.LastErrMsg
@ -310,12 +339,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
continue continue
} }
// reap subprocess when it exits
go func() {
// Exit status managed via getServerStatus
_ = s.cmd.Wait()
}()
return s, nil return s, nil
} }
@ -347,12 +370,27 @@ type ServerStatus int
const ( // iota is reset to 0 const ( // iota is reset to 0
ServerStatusReady ServerStatus = iota ServerStatusReady ServerStatus = iota
ServerStatusNoSlotsAvaialble ServerStatusNoSlotsAvailable
ServerStatusLoadingModel ServerStatusLoadingModel
ServerStatusNotResponding ServerStatusNotResponding
ServerStatusError ServerStatusError
) )
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "llm server ready"
case ServerStatusNoSlotsAvailable:
return "llm busy - no slots available"
case ServerStatusLoadingModel:
return "llm server loading model"
case ServerStatusNotResponding:
return "llm server not responding"
default:
return "llm server error"
}
}
type ServerStatusResp struct { type ServerStatusResp struct {
Status string `json:"status"` Status string `json:"status"`
SlotsIdle int `json:"slots_idle"` SlotsIdle int `json:"slots_idle"`
@ -360,13 +398,17 @@ type ServerStatusResp struct {
Error string `json:"error"` Error string `json:"error"`
} }
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) { func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
// Fail fast if its exited // Fail fast if its exited
if s.cmd.ProcessState != nil { if s.cmd.ProcessState != nil {
msg := "" msg := ""
if s.status != nil && s.status.LastErrMsg != "" { if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg msg = s.status.LastErrMsg
} }
if s.cmd.ProcessState.ExitCode() == -1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
}
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
} }
@ -399,7 +441,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
case "ok": case "ok":
return ServerStatusReady, nil return ServerStatusReady, nil
case "no slot available": case "no slot available":
return ServerStatusNoSlotsAvaialble, nil return ServerStatusNoSlotsAvailable, nil
case "loading model": case "loading model":
return ServerStatusLoadingModel, nil return ServerStatusLoadingModel, nil
default: default:
@ -407,7 +449,30 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
} }
} }
func (s *LlamaServer) Ping(ctx context.Context) error { // getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
var retries int
for {
status, err := s.getServerStatus(ctx)
if err != nil {
return status, err
}
if status == ServerStatusNoSlotsAvailable {
if retries >= 10 {
return status, fmt.Errorf("no slots available after %d retries", retries)
}
time.Sleep(5 * time.Millisecond)
retries++
continue
}
return status, nil
}
}
func (s *llmServer) Ping(ctx context.Context) error {
_, err := s.getServerStatus(ctx) _, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
slog.Debug("server unhealthy", "error", err) slog.Debug("server unhealthy", "error", err)
@ -416,13 +481,25 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
return nil return nil
} }
func (s *LlamaServer) WaitUntilRunning() error { func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
start := time.Now() start := time.Now()
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
slog.Info("waiting for llama runner to start responding") slog.Info("waiting for llama runner to start responding")
for { for {
select {
case <-ctx.Done():
slog.Info("context expired before server started")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case err := <-s.done:
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
default:
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel() defer cancel()
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
@ -487,7 +564,6 @@ ws ::= ([ \t\n] ws)?
` `
const maxBufferSize = 512 * format.KiloByte const maxBufferSize = 512 * format.KiloByte
const maxRetries = 3
type ImageData struct { type ImageData struct {
Data []byte `json:"data"` Data []byte `json:"data"`
@ -524,7 +600,19 @@ type CompletionResponse struct {
EvalDuration time.Duration EvalDuration time.Duration
} }
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return err
}
defer s.sem.Release(1)
// only allow maximum 10 "context shifts" to avoid infinite generation
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
}
request := map[string]any{ request := map[string]any{
"prompt": req.Prompt, "prompt": req.Prompt,
"stream": true, "stream": true,
@ -551,11 +639,11 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
} }
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatusRetry(ctx)
if err != nil { if err != nil {
return err return err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %d", status) return fmt.Errorf("unexpected server status: %s", status.ToString())
} }
if req.Format == "json" { if req.Format == "json" {
@ -565,13 +653,6 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
} }
} }
retryDelay := 100 * time.Microsecond
for retries := 0; retries < maxRetries; retries++ {
if retries > 0 {
time.Sleep(retryDelay) // wait before retrying
retryDelay *= 2 // exponential backoff
}
// Handling JSON marshaling with special characters unescaped. // Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{} buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer) enc := json.NewEncoder(buffer)
@ -582,20 +663,20 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
} }
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port) endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer) serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil { if err != nil {
return fmt.Errorf("error creating POST request: %v", err) return fmt.Errorf("error creating POST request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") serverReq.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(serverReq)
if err != nil { if err != nil {
return fmt.Errorf("POST predict: %v", err) return fmt.Errorf("POST predict: %v", err)
} }
defer resp.Body.Close() defer res.Body.Close()
if resp.StatusCode >= 400 { if res.StatusCode >= 400 {
bodyBytes, err := io.ReadAll(resp.Body) bodyBytes, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed reading llm error response: %w", err) return fmt.Errorf("failed reading llm error response: %w", err)
} }
@ -603,11 +684,10 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
return fmt.Errorf("%s", bodyBytes) return fmt.Errorf("%s", bodyBytes)
} }
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(res.Body)
buf := make([]byte, 0, maxBufferSize) buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize) scanner.Buffer(buf, maxBufferSize)
retryNeeded := false
// keep track of the last token generated, this is used to abort if the model starts looping // keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string var lastToken string
var tokenRepeat int var tokenRepeat int
@ -623,12 +703,6 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
continue continue
} }
// try again on slot unavailable
if bytes.Contains(line, []byte("slot unavailable")) {
retryNeeded = true
break
}
evt, ok := bytes.CutPrefix(line, []byte("data: ")) evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok { if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line) return fmt.Errorf("error parsing llm response stream: %s", line)
@ -679,19 +753,13 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
if s.status != nil && s.status.LastErrMsg != "" { if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg msg = s.status.LastErrMsg
} }
return fmt.Errorf("an unknown error was encountered while running the model %s", msg) return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
} }
return fmt.Errorf("error reading llm response: %v", err) return fmt.Errorf("error reading llm response: %v", err)
} }
if !retryNeeded { return nil
return nil // success
}
}
// should never reach here ideally
return fmt.Errorf("max retries exceeded")
} }
type EmbeddingRequest struct { type EmbeddingRequest struct {
@ -702,13 +770,19 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) { func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
defer s.sem.Release(1)
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatusRetry(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != ServerStatusReady { } else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(TokenizeRequest{Content: prompt}) data, err := json.Marshal(TokenizeRequest{Content: prompt})
@ -754,13 +828,13 @@ type TokenizeResponse struct {
Tokens []int `json:"tokens"` Tokens []int `json:"tokens"`
} }
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} else if status != ServerStatusReady { } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return nil, fmt.Errorf("unexpected server status: %d", status) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(TokenizeRequest{Content: content}) data, err := json.Marshal(TokenizeRequest{Content: content})
@ -806,13 +880,13 @@ type DetokenizeResponse struct {
Content string `json:"content"` Content string `json:"content"`
} }
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) { func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatus(ctx) status, err := s.getServerStatus(ctx)
if err != nil { if err != nil {
return "", err return "", err
} else if status != ServerStatusReady { } else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return "", fmt.Errorf("unexpected server status: %d", status) return "", fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens}) data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@ -850,15 +924,25 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
return decoded.Content, nil return decoded.Content, nil
} }
func (s *LlamaServer) Close() error { func (s *llmServer) Close() error {
if s.cmd != nil { if s.cmd != nil {
slog.Debug("stopping llama server") slog.Debug("stopping llama server")
return s.cmd.Process.Kill() if err := s.cmd.Process.Kill(); err != nil {
return err
}
_ = s.cmd.Wait()
slog.Debug("llama server stopped")
} }
return nil return nil
} }
func (s *llmServer) EstimatedVRAM() uint64 {
return s.estimatedVRAM
}
func parseDurationMs(ms float64) time.Duration { func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil { if err != nil {

View file

@ -19,7 +19,7 @@ export default function () {
const [step, setStep] = useState<Step>(Step.WELCOME) const [step, setStep] = useState<Step>(Step.WELCOME)
const [commandCopied, setCommandCopied] = useState<boolean>(false) const [commandCopied, setCommandCopied] = useState<boolean>(false)
const command = 'ollama run llama2' const command = 'ollama run llama3'
return ( return (
<div className='drag'> <div className='drag'>

View file

@ -1,132 +0,0 @@
package parser
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log/slog"
"slices"
)
type Command struct {
Name string
Args string
}
func (c *Command) Reset() {
c.Name = ""
c.Args = ""
}
func Parse(reader io.Reader) ([]Command, error) {
var commands []Command
var command, modelCommand Command
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
scanner.Split(scanModelfile)
for scanner.Scan() {
line := scanner.Bytes()
fields := bytes.SplitN(line, []byte(" "), 2)
if len(fields) == 0 || len(fields[0]) == 0 {
continue
}
switch string(bytes.ToUpper(fields[0])) {
case "FROM":
command.Name = "model"
command.Args = string(bytes.TrimSpace(fields[1]))
// copy command for validation
modelCommand = command
case "ADAPTER":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(bytes.TrimSpace(fields[1]))
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1])
case "PARAMETER":
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("missing value for %s", fields)
}
command.Name = string(fields[0])
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default:
if !bytes.HasPrefix(fields[0], []byte("#")) {
// log a warning for unknown commands
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
}
continue
}
commands = append(commands, command)
command.Reset()
}
if modelCommand.Args == "" {
return nil, errors.New("no FROM line for the model was specified")
}
return commands, scanner.Err()
}
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
if err != nil {
return 0, nil, err
}
if advance > 0 && token != nil {
return advance, token, nil
}
advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
if err != nil {
return 0, nil, err
}
if advance > 0 && token != nil {
return advance, token, nil
}
return bufio.ScanLines(data, atEOF)
}
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
newline := bytes.IndexByte(data, '\n')
if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
end := bytes.Index(data[start+len(openBytes):], closeBytes)
if end < 0 {
if atEOF {
return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
} else {
return 0, nil, nil
}
}
n := start + len(openBytes) + end + len(closeBytes)
newData := data[:start]
newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
return n, newData, nil
}
return 0, nil, nil
}

View file

@ -1,98 +0,0 @@
package parser
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_Parser(t *testing.T) {
input := `
FROM model1
ADAPTER adapter1
LICENSE MIT
PARAMETER param1 value1
PARAMETER param2 value2
TEMPLATE template1
`
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "model1"},
{Name: "adapter", Args: "adapter1"},
{Name: "license", Args: "MIT"},
{Name: "param1", Args: "value1"},
{Name: "param2", Args: "value2"},
{Name: "template", Args: "template1"},
}
assert.Equal(t, expectedCommands, commands)
}
func Test_Parser_NoFromLine(t *testing.T) {
input := `
PARAMETER param1 value1
PARAMETER param2 value2
`
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "no FROM line")
}
func Test_Parser_MissingValue(t *testing.T) {
input := `
FROM foo
PARAMETER param1
`
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "missing value for [param1]")
}
func Test_Parser_Messages(t *testing.T) {
input := `
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
}
assert.Equal(t, expectedCommands, commands)
}
func Test_Parser_Messages_BadRole(t *testing.T) {
input := `
FROM foo
MESSAGE badguy I'm a bad guy!
`
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
}

View file

@ -218,7 +218,7 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlZ: case CharCtrlZ:
fd := int(syscall.Stdin) fd := int(syscall.Stdin)
return handleCharCtrlZ(fd, i.Terminal.termios) return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter: case CharEnter, CharCtrlJ:
output := buf.String() output := buf.String()
if output != "" { if output != "" {
i.History.Add([]rune(output)) i.History.Add([]rune(output))
@ -232,7 +232,7 @@ func (i *Instance) Readline() (string, error) {
metaDel = false metaDel = false
continue continue
} }
if r >= CharSpace || r == CharEnter { if r >= CharSpace || r == CharEnter || r == CharCtrlJ {
buf.Add(r) buf.Add(r)
} }
} }

Some files were not shown because too many files have changed in this diff Show more