Merge remote-tracking branch 'upstream/main' into pr3702

This commit is contained in:
Daniel Hiltgen 2024-05-08 16:44:35 -07:00
commit 920a4b0794
132 changed files with 7701 additions and 4766 deletions

View file

@ -103,6 +103,7 @@ jobs:
path: |
llm/build/**/bin/*
llm/build/**/*.a
dist/windows-amd64/**
# ROCm generation step
generate-windows-rocm:
@ -173,7 +174,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: generate-windows-rocm
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
- uses: actions/upload-artifact@v4
with:
name: windows-rocm-deps
@ -253,7 +256,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: generate-windows-cuda
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
- uses: actions/upload-artifact@v4
with:
name: windows-cuda-deps
@ -306,23 +311,18 @@ jobs:
- uses: actions/download-artifact@v4
with:
name: generate-windows-cpu
path: llm/build
- uses: actions/download-artifact@v4
with:
name: generate-windows-cuda
path: llm/build
- uses: actions/download-artifact@v4
with:
name: windows-cuda-deps
path: dist/deps
- uses: actions/download-artifact@v4
with:
name: windows-rocm-deps
path: dist/deps
- uses: actions/download-artifact@v4
with:
name: generate-windows-rocm
path: llm/build
- run: dir llm/build
- run: |
$gopath=(get-command go).source | split-path -parent
@ -331,13 +331,13 @@ jobs:
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
$env:PATH="$gopath;$env:PATH"
$env:OLLAMA_SKIP_GENERATE="1"
$env:NVIDIA_DIR=$(resolve-path ".\dist\deps")
$env:HIP_PATH=$(resolve-path ".\dist\deps")
& .\scripts\build_windows.ps1
- uses: actions/upload-artifact@v4
with:
name: dist-windows
path: dist/*.exe
path: |
dist/OllamaSetup.exe
dist/ollama-windows-*.zip
# Linux x86 assets built using the container based build
build-linux-amd64:

View file

@ -1,5 +1,15 @@
name: test
concurrency:
# For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
# cancels running CI jobs and starts all new ones.
#
# For non-PR pushes, concurrency.group needs to be unique for every distinct
# CI run we want to have happen. Use run_id, which in practice means all
# non-PR CI runs will be allowed to run without preempting each other.
group: ${{ github.workflow }}-$${{ github.pull_request.number || github.run_id }}
cancel-in-progress: true
on:
pull_request:
paths:
@ -21,7 +31,9 @@ jobs:
- id: changes
run: |
changed() {
git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
git diff-tree -r --no-commit-id --name-only \
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
${{ github.event.pull_request.head.sha }} \
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
}
@ -103,7 +115,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: cuda-${{ matrix.cuda-version }}-libraries
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
generate-rocm:
needs: [changes]
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
@ -134,7 +148,9 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/build/**/bin/*
path: |
llm/build/**/bin/*
dist/windows-amd64/**
# ROCm generation step
generate-windows-rocm:
@ -253,14 +269,9 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
shell: bash
- uses: golangci/golangci-lint-action@v4
with:
args: --timeout 8m0s
args: --timeout 8m0s -v
test:
strategy:
matrix:
@ -284,7 +295,6 @@ jobs:
with:
go-version-file: go.mod
cache: true
- run: go get
- run: |
case ${{ matrix.arch }} in
amd64) echo ARCH=x86_64 ;;
@ -299,10 +309,6 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
shell: bash
- run: go generate ./...
- run: go build

3
.gitignore vendored
View file

@ -11,4 +11,5 @@ ggml-metal.metal
.idea
test_data
*.crt
llm/build
llm/build
__debug_bin*

View file

@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION
@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64
ARG CMAKE_VERSION
@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
RUN mkdir /tmp/scratch && \
for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
cp ${dep} /tmp/scratch/ || exit 1 ; \
@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
ARG CMAKE_VERSION
@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
# Intermediate stage used for ./scripts/build_linux.sh

View file

@ -1,5 +1,5 @@
<div align="center">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
 <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</div>
# Ollama
@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
## Quickstart
To run and chat with [Llama 2](https://ollama.com/library/llama2):
To run and chat with [Llama 3](https://ollama.com/library/llama3):
```
ollama run llama2
ollama run llama3
```
## Model library
@ -49,17 +49,14 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | ------------------------------ |
| Llama 2 | 7B | 3.8GB | `ollama run llama2` |
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
| Mistral | 7B | 4.1GB | `ollama run mistral` |
| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` |
| Phi-2 | 2.7B | 1.7GB | `ollama run phi` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
| Starling | 7B | 4.1GB | `ollama run starling-lm` |
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| Llama 2 13B | 13B | 7.3GB | `ollama run llama2:13b` |
| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` |
| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
| Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
@ -97,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information.
### Customize a prompt
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model:
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
```
ollama pull llama2
ollama pull llama3
```
Create a `Modelfile`:
```
FROM llama2
FROM llama3
# set the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1
@ -141,7 +138,7 @@ ollama create mymodel -f ./Modelfile
### Pull a model
```
ollama pull llama2
ollama pull llama3
```
> This command can also be used to update a local model. Only the diff will be pulled.
@ -149,13 +146,13 @@ ollama pull llama2
### Remove a model
```
ollama rm llama2
ollama rm llama3
```
### Copy a model
```
ollama cp llama2 my-llama2
ollama cp llama3 my-model
```
### Multiline input
@ -176,10 +173,10 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
The image features a yellow smiley face, which is likely the central focus of the picture.
```
### Pass in prompt as arguments
### Pass the prompt as an argument
```
$ ollama run llama2 "Summarize this file: $(cat README.md)"
$ ollama run llama3 "Summarize this file: $(cat README.md)"
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
```
@ -226,7 +223,7 @@ Next, start the server:
Finally, in a separate shell, run a model:
```
./ollama run llama2
./ollama run llama3
```
## REST API
@ -237,7 +234,7 @@ Ollama has a REST API for running and managing models.
```
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama3",
"prompt":"Why is the sky blue?"
}'
```
@ -246,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"model": "llama3",
"messages": [
{ "role": "user", "content": "why is the sky blue?" }
]
@ -259,16 +256,18 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Web & Desktop
- [Open WebUI](https://github.com/open-webui/open-webui)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [Hollama](https://github.com/fmaclen/hollama)
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Open WebUI](https://github.com/open-webui/open-webui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
@ -286,13 +285,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaGUI](https://github.com/enoch1118/ollamaGUI)
- [OpenAOE](https://github.com/InternLM/OpenAOE)
- [Odin Runes](https://github.com/leonid20000/OdinRunes)
- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x)
- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App)
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding)
- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold)
- [chat](https://github.com/swuecho/chat) (chat web app for teams)
- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama)
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG)
- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation)
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
### Terminal
@ -308,11 +314,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
- [cmdh](https://github.com/pgibler/cmdh)
- [ooo](https://github.com/npahlfer/ooo)
- [shell-pilot](https://github.com/reid41/shell-pilot)
- [tenere](https://github.com/pythops/tenere)
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
- [typechat-cli](https://github.com/anaisbetts/typechat-cli)
- [ShellOracle](https://github.com/djcopley/ShellOracle)
- [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
### Database
@ -344,9 +352,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
- [Elixir LangChain](https://github.com/brainlid/langchain)
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r)
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama)
- [Testcontainers](https://testcontainers.com/modules/ollama/)
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
### Mobile
@ -366,17 +376,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram)
- [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation)
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
- [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support)
- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot)
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

View file

@ -1,9 +1,16 @@
// Package api implements the client-side API for code wishing to interact
// with the ollama service. The methods of the [Client] type correspond to
// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
//
// the ollama REST API as described in [the API documentation].
// The ollama command-line client itself uses this package to interact with
// the backend service.
//
// # Examples
//
// Several examples of using this package are available [in the GitHub
// repository].
//
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
package api
import (
@ -18,6 +25,7 @@ import (
"net/url"
"os"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/format"
@ -57,12 +65,36 @@ func checkError(resp *http.Response, body []byte) error {
// If the variable is not specified, a default ollama host and port will be
// used.
func ClientFromEnvironment() (*Client, error) {
ollamaHost, err := GetOllamaHost()
if err != nil {
return nil, err
}
return &Client{
base: &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
},
http: http.DefaultClient,
}, nil
}
type OllamaHost struct {
Scheme string
Host string
Port string
}
func GetOllamaHost() (OllamaHost, error) {
defaultPort := "11434"
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
hostVar := os.Getenv("OLLAMA_HOST")
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
scheme, hostport, ok := strings.Cut(hostVar, "://")
switch {
case !ok:
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
scheme, hostport = "http", hostVar
case scheme == "http":
defaultPort = "80"
case scheme == "https":
@ -82,15 +114,24 @@ func ClientFromEnvironment() (*Client, error) {
}
}
return &Client{
base: &url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http: http.DefaultClient,
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
return OllamaHost{}, ErrInvalidHostPort
}
return OllamaHost{
Scheme: scheme,
Host: host,
Port: port,
}, nil
}
func NewClient(base *url.URL, http *http.Client) *Client {
return &Client{
base: base,
http: http,
}
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader
var data []byte
@ -265,8 +306,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
})
}
// PushProgressFunc is a function that [Client.Push] invokes when progress is
// made.
// It's similar to other progress function types like [PullProgressFunc].
type PushProgressFunc func(ProgressResponse) error
// Push uploads a model to the model library; requires registering for ollama.ai
// and adding a public key first. fn is called each time progress is made on
// the request and can be used to display a progress bar, etc.
func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error {
var resp ProgressResponse
@ -278,8 +325,15 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc
})
}
// CreateProgressFunc is a function that [Client.Create] invokes when progress
// is made.
// It's similar to other progress function types like [PullProgressFunc].
type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]).
//
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse
@ -291,6 +345,7 @@ func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgre
})
}
// List lists models that are available locally.
func (c *Client) List(ctx context.Context) (*ListResponse, error) {
var lr ListResponse
if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil {
@ -299,6 +354,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
return &lr, nil
}
// Copy copies a model - creating a model with another name from an existing
// model.
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil {
return err
@ -306,6 +363,7 @@ func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
return nil
}
// Delete deletes a model and its data.
func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil {
return err
@ -313,6 +371,7 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
return nil
}
// Show obtains model information, including details, modelfile, license etc.
func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
var resp ShowResponse
if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil {
@ -321,12 +380,16 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
return &resp, nil
}
// Hearbeat checks if the server has started and is responsive; if yes, it
// returns nil, otherwise an error.
func (c *Client) Heartbeat(ctx context.Context) error {
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
return err
}
return nil
}
// Embeddings generates embeddings from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
@ -335,10 +398,13 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
return &resp, nil
}
// CreateBlob creates a blob from a file on the server. digest is the
// expected SHA256 digest of the file, and r represents the file.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
}
// Version returns the Ollama server version as a string.
func (c *Client) Version(ctx context.Context) (string, error) {
var version struct {
Version string `json:"version"`

View file

@ -1,6 +1,12 @@
package api
import "testing"
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestClientFromEnvironment(t *testing.T) {
type testCase struct {
@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
}
})
}
hostTestCases := map[string]*testCase{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
"zero port": {value: ":0", expect: ":0"},
"too large port": {value: ":66000", err: ErrInvalidHostPort},
"too small port": {value: ":-1", err: ErrInvalidHostPort},
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
}
for k, v := range hostTestCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
oh, err := GetOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if err == nil {
host := net.JoinHostPort(oh.Host, oh.Port)
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
}
})
}
}

View file

@ -2,6 +2,7 @@ package api
import (
"encoding/json"
"errors"
"fmt"
"math"
"os"
@ -11,6 +12,7 @@ import (
"time"
)
// StatusError is an error with and HTTP status code.
type StatusError struct {
StatusCode int
Status string
@ -31,6 +33,7 @@ func (e StatusError) Error() string {
}
}
// ImageData represents the raw binary data of an image file.
type ImageData []byte
// GenerateRequest describes a request sent by [Client.Generate]. While you
@ -76,22 +79,39 @@ type GenerateRequest struct {
Options map[string]interface{} `json:"options"`
}
// ChatRequest describes a request sent by [Client.Chat].
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
// Model is the model name, as in [GenerateRequest].
Model string `json:"model"`
// Messages is the messages of the chat - can be used to keep a chat memory.
Messages []Message `json:"messages"`
// Stream enable streaming of returned response; true by default.
Stream *bool `json:"stream,omitempty"`
// Format is the format to return the response in (e.g. "json").
Format string `json:"format"`
// KeepAlive controls how long the model will stay loaded into memory
// followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list
// of images.
type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
Role string `json:"role"`
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
@ -111,7 +131,8 @@ type Metrics struct {
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
// Options specified in GenerateRequest, if you add a new option here add it to the API docs also
// Options specified in [GenerateRequest], if you add a new option here add it
// to the API docs also.
type Options struct {
Runner
@ -157,18 +178,28 @@ type Runner struct {
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
// Model is the model name.
Model string `json:"model"`
// Prompt is the textual prompt to embed.
Prompt string `json:"prompt"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// EmbeddingResponse is the response from [Client.Embeddings].
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
Model string `json:"model"`
Path string `json:"path"`
@ -180,6 +211,7 @@ type CreateRequest struct {
Name string `json:"name"`
}
// DeleteRequest is the request passed to [Client.Delete].
type DeleteRequest struct {
Model string `json:"model"`
@ -187,6 +219,7 @@ type DeleteRequest struct {
Name string `json:"name"`
}
// ShowRequest is the request passed to [Client.Show].
type ShowRequest struct {
Model string `json:"model"`
System string `json:"system"`
@ -198,6 +231,7 @@ type ShowRequest struct {
Name string `json:"name"`
}
// ShowResponse is the response returned from [Client.Show].
type ShowResponse struct {
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
@ -208,11 +242,13 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"`
}
// CopyRequest is the request passed to [Client.Copy].
type CopyRequest struct {
Source string `json:"source"`
Destination string `json:"destination"`
}
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
@ -224,6 +260,8 @@ type PullRequest struct {
Name string `json:"name"`
}
// ProgressResponse is the response passed to progress functions like
// [PullProgressFunc] and [PushProgressFunc].
type ProgressResponse struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
@ -231,6 +269,7 @@ type ProgressResponse struct {
Completed int64 `json:"completed,omitempty"`
}
// PushRequest is the request passed to [Client.Push].
type PushRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
@ -242,10 +281,12 @@ type PushRequest struct {
Name string `json:"name"`
}
// ListResponse is the response from [Client.List].
type ListResponse struct {
Models []ModelResponse `json:"models"`
}
// ModelResponse is a single model description in [ListResponse].
type ModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
@ -259,17 +300,28 @@ type TokenResponse struct {
Token string `json:"token"`
}
// GenerateResponse is the response passed into [GenerateResponseFunc].
type GenerateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response"`
// Model is the model name that generated the response.
Model string `json:"model"`
Done bool `json:"done"`
//CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`
// Response is the textual response itself.
Response string `json:"response"`
// Done specifies if the response is complete.
Done bool `json:"done"`
// Context is an encoding of the conversation used in this response; this
// can be sent in the next request to keep a conversational memory.
Context []int `json:"context,omitempty"`
Metrics
}
// ModelDetails provides details about a model.
type ModelDetails struct {
ParentModel string `json:"parent_model"`
Format string `json:"format"`
@ -307,7 +359,9 @@ func (m *Metrics) Summary() {
}
}
var ErrInvalidOpts = fmt.Errorf("invalid options")
// ErrInvalidOpts is returned when invalid options are passed to the client.
var ErrInvalidOpts = errors.New("invalid options")
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
@ -392,11 +446,15 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
return nil
}
// DefaultOptions is the default set of options for [GenerateRequest]; these
// values are used unless the user specifies other values explicitly.
func DefaultOptions() Options {
return Options{
// options set on request to runner
NumPredict: -1,
NumKeep: 0,
NumPredict: -1,
// set a minimal num_keep to avoid issues on context shifts
NumKeep: 4,
Temperature: 0.8,
TopK: 40,
TopP: 0.9,
@ -432,6 +490,13 @@ type Duration struct {
time.Duration
}
func (d Duration) MarshalJSON() ([]byte, error) {
if d.Duration < 0 {
return []byte("-1"), nil
}
return []byte("\"" + d.Duration.String() + "\""), nil
}
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
var v any
if err := json.Unmarshal(b, &v); err != nil {
@ -445,7 +510,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if t < 0 {
d.Duration = time.Duration(math.MaxInt64)
} else {
d.Duration = time.Duration(t * float64(time.Second))
d.Duration = time.Duration(int(t) * int(time.Second))
}
case string:
d.Duration, err = time.ParseDuration(t)
@ -455,6 +520,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if d.Duration < 0 {
d.Duration = time.Duration(math.MaxInt64)
}
default:
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
}
return nil

View file

@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req: `{ "keep_alive": 42 }`,
exp: &Duration{42 * time.Second},
},
{
name: "Positive Float",
req: `{ "keep_alive": 42.5 }`,
exp: &Duration{42 * time.Second},
},
{
name: "Positive Integer String",
req: `{ "keep_alive": "42m" }`,
@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req: `{ "keep_alive": -1 }`,
exp: &Duration{math.MaxInt64},
},
{
name: "Negative Float",
req: `{ "keep_alive": -3.14 }`,
exp: &Duration{math.MaxInt64},
},
{
name: "Negative Integer String",
req: `{ "keep_alive": "-1m" }`,
@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
})
}
}
func TestDurationMarshalUnmarshal(t *testing.T) {
tests := []struct {
name string
input time.Duration
expected time.Duration
}{
{
"negative duration",
time.Duration(-1),
time.Duration(math.MaxInt64),
},
{
"positive duration",
time.Duration(42 * time.Second),
time.Duration(42 * time.Second),
},
{
"another positive duration",
time.Duration(42 * time.Minute),
time.Duration(42 * time.Minute),
},
{
"zero duration",
time.Duration(0),
time.Duration(0),
},
{
"max duration",
time.Duration(math.MaxInt64),
time.Duration(math.MaxInt64),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b, err := json.Marshal(Duration{test.input})
require.NoError(t, err)
var d Duration
err = json.Unmarshal(b, &d)
require.NoError(t, err)
assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
})
}
}

View file

@ -5,12 +5,14 @@ import (
"log/slog"
"os"
"path/filepath"
"github.com/ollama/ollama/server/envconfig"
)
func InitLogging() {
level := slog.LevelInfo
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
if envconfig.Debug {
level = slog.LevelDebug
}

View file

@ -43,37 +43,36 @@ func getCLIFullPath(command string) string {
return command
}
func SpawnServer(ctx context.Context, command string) (chan int, error) {
done := make(chan int)
logDir := filepath.Dir(ServerLogFile)
_, err := os.Stat(logDir)
if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
func start(ctx context.Context, command string) (*exec.Cmd, error) {
cmd := getCmd(ctx, getCLIFullPath(command))
// send stdout and stderr to a file
stdout, err := cmd.StdoutPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
}
stdin, err := cmd.StdinPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
}
// TODO - rotation
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
if err != nil {
return done, fmt.Errorf("failed to create server log %w", err)
return nil, fmt.Errorf("failed to create server log: %w", err)
}
logDir := filepath.Dir(ServerLogFile)
_, err = os.Stat(logDir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
}
if err := os.MkdirAll(logDir, 0o755); err != nil {
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
go func() {
defer logFile.Close()
io.Copy(logFile, stdout) //nolint:errcheck
@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
// run the command and wait for it to finish
if err := cmd.Start(); err != nil {
return done, fmt.Errorf("failed to start server %w", err)
return nil, fmt.Errorf("failed to start server %w", err)
}
if cmd.Process != nil {
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
}
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
return cmd, nil
}
func SpawnServer(ctx context.Context, command string) (chan int, error) {
done := make(chan int)
go func() {
// Keep the server running unless we're shuttind down the app
crashCount := 0
for {
slog.Info("starting server...")
cmd, err := start(ctx, command)
if err != nil {
crashCount++
slog.Error(fmt.Sprintf("failed to start server %s", err))
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
continue
}
cmd.Wait() //nolint:errcheck
stdin.Close()
var code int
if cmd.ProcessState != nil {
code = cmd.ProcessState.ExitCode()
@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
default:
crashCount++
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
time.Sleep(500 * time.Millisecond)
if err := cmd.Start(); err != nil {
slog.Error(fmt.Sprintf("failed to restart server %s", err))
// Keep trying, but back off if we keep failing
time.Sleep(time.Duration(crashCount) * time.Second)
}
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
break
}
}
}()
return done, nil
}

View file

@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error {
"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
}
// When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts)
// TODO - temporarily disable since we're pinning in debug mode for the preview
// if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" {
// make the upgrade as quiet as possible (no GUI, no prompts)
installArgs = append(installArgs,
"/SP", // Skip the "This will install... Do you wish to continue" prompt
"/SUPPRESSMSGBOXES",
"/SILENT",
"/VERYSILENT",
)
// }
// Safeguard in case we have requests in flight that need to drain...
slog.Info("Waiting for server to shutdown")

View file

@ -88,15 +88,12 @@ DialogFontSize=12
[Files]
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
; Assumes v5.7, may need adjustments for v6
#if GetEnv("HIP_PATH") != ""
Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion
; amdhip64.dll dependency comes from the driver and must be installed already
Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion
#if DirExists("..\dist\windows-amd64\rocm")
Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs
#endif
@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
;FinishedHeadingLabel=Run your first model
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama2
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3
;ClickFinish=%n
[Registry]

View file

@ -1,71 +1,71 @@
//go:build windows
package wintray
import (
"fmt"
"log/slog"
"unsafe"
"golang.org/x/sys/windows"
)
const (
updatAvailableMenuID = 1
updateMenuID = updatAvailableMenuID + 1
separatorMenuID = updateMenuID + 1
diagLogsMenuID = separatorMenuID + 1
diagSeparatorMenuID = diagLogsMenuID + 1
quitMenuID = diagSeparatorMenuID + 1
)
func (t *winTray) initMenus() error {
if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w\n", err)
}
if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w\n", err)
}
return nil
}
func (t *winTray) UpdateAvailable(ver string) error {
if !t.updateNotified {
slog.Debug("updating menu and sending notification for new update")
if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
if err != nil {
return fmt.Errorf("unable to write icon data to temp file: %w", err)
}
if err := wt.setIcon(iconFilePath); err != nil {
return fmt.Errorf("unable to set icon: %w", err)
}
t.updateNotified = true
t.pendingUpdate = true
// Now pop up the notification
t.muNID.Lock()
defer t.muNID.Unlock()
copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
t.nid.Flags |= NIF_INFO
t.nid.Timeout = 10
t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
err = t.nid.modify()
if err != nil {
return err
}
}
return nil
}
//go:build windows
package wintray
import (
"fmt"
"log/slog"
"unsafe"
"golang.org/x/sys/windows"
)
const (
updatAvailableMenuID = 1
updateMenuID = updatAvailableMenuID + 1
separatorMenuID = updateMenuID + 1
diagLogsMenuID = separatorMenuID + 1
diagSeparatorMenuID = diagLogsMenuID + 1
quitMenuID = diagSeparatorMenuID + 1
)
func (t *winTray) initMenus() error {
if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w\n", err)
}
if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w\n", err)
}
return nil
}
func (t *winTray) UpdateAvailable(ver string) error {
if !t.updateNotified {
slog.Debug("updating menu and sending notification for new update")
if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
iconFilePath, err := iconBytesToFilePath(wt.updateIcon)
if err != nil {
return fmt.Errorf("unable to write icon data to temp file: %w", err)
}
if err := wt.setIcon(iconFilePath); err != nil {
return fmt.Errorf("unable to set icon: %w", err)
}
t.updateNotified = true
t.pendingUpdate = true
// Now pop up the notification
t.muNID.Lock()
defer t.muNID.Unlock()
copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle))
copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver)))
t.nid.Flags |= NIF_INFO
t.nid.Timeout = 10
t.nid.Size = uint32(unsafe.Sizeof(*wt.nid))
err = t.nid.modify()
if err != nil {
return err
}
}
return nil
}

View file

@ -10,12 +10,44 @@ import (
"log/slog"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
)
const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil {
return "", err
}
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
return strings.TrimSpace(string(publicKey)), nil
}
func NewNonce(r io.Reader, length int) (string, error) {
nonce := make([]byte, length)
if _, err := io.ReadFull(r, nonce); err != nil {
@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
home, err := os.UserHomeDir()
keyPath, err := keyPath()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View file

@ -17,6 +17,7 @@ import (
"os"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"strings"
"syscall"
@ -31,10 +32,12 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@ -53,14 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
modelfile, err := os.ReadFile(filename)
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
commands, err := parser.Parse(bytes.NewReader(modelfile))
modelfile, err := model.ParseFile(f)
if err != nil {
return err
}
@ -74,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
for _, c := range commands {
switch c.Name {
for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
case "model", "adapter":
path := c.Args
path := modelfile.Commands[i].Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
@ -89,101 +91,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue
} else if err != nil {
return err
}
// TODO make this work w/ adapters
if fi.IsDir() {
tf, err := os.CreateTemp("", "ollama-tf")
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
if err != nil {
return err
}
defer os.RemoveAll(tf.Name())
defer os.RemoveAll(tempfile)
zf := zip.NewWriter(tf)
files := []string{}
tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
if err != nil {
return err
} else if len(tfiles) == 0 {
tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return err
}
}
files = append(files, tfiles...)
if len(files) == 0 {
return fmt.Errorf("no models were found in '%s'", path)
}
// add the safetensor/torch config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "params.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))
for _, fn := range files {
f, err := os.Open(fn)
// just skip whatever files aren't there
if os.IsNotExist(err) {
if strings.HasSuffix(fn, "tokenizer.model") {
// try the parent dir before giving up
parentDir := filepath.Dir(path)
newFn := filepath.Join(parentDir, "tokenizer.model")
f, err = os.Open(newFn)
if os.IsNotExist(err) {
continue
} else if err != nil {
return err
}
} else {
continue
}
} else if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
h, err := zip.FileInfoHeader(fi)
if err != nil {
return err
}
h.Name = filepath.Base(fn)
h.Method = zip.Store
w, err := zf.CreateHeader(h)
if err != nil {
return err
}
_, err = io.Copy(w, f)
if err != nil {
return err
}
}
if err := zf.Close(); err != nil {
return err
}
if err := tf.Close(); err != nil {
return err
}
path = tf.Name()
path = tempfile
}
digest, err := createBlob(cmd, client, path)
@ -191,10 +114,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest))
modelfile.Commands[i].Args = "@" + digest
}
}
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
spinner.Stop()
@ -220,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
quantization, _ := cmd.Flags().GetString("quantization")
request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization}
if err := client.Create(cmd.Context(), &request, fn); err != nil {
return err
}
@ -228,6 +152,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return "", err
}
defer tempfile.Close()
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
}
}
return matches, nil
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else {
return "", errors.New("no safetensors or torch files found")
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
if err != nil {
return "", err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
for _, file := range files {
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return "", err
}
zfi, err := zip.FileInfoHeader(fi)
if err != nil {
return "", err
}
zf, err := zipfile.CreateHeader(zfi)
if err != nil {
return "", err
}
if _, err := io.Copy(zf, f); err != nil {
return "", err
}
}
return tempfile.Name(), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path)
if err != nil {
@ -322,6 +354,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
func errFromUnknownKey(unknownKeyErr error) error {
// find SSH public key in the error message
sshKeyPattern := `ssh-\w+ [^\s"]+`
re := regexp.MustCompile(sshKeyPattern)
matches := re.FindStringSubmatch(unknownKeyErr.Error())
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
if err != nil {
return unknownKeyErr
}
localPubKey = strings.TrimSpace(string(svcPubKey))
}
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
if serverPubKey != localPubKey {
return unknownKeyErr
}
var msg strings.Builder
msg.WriteString(unknownKeyErr.Error())
msg.WriteString("\n\nYour ollama key is:\n")
msg.WriteString(localPubKey)
msg.WriteString("\nAdd your key at:\n")
msg.WriteString("https://ollama.com/settings/keys")
return errors.New(msg.String())
}
return unknownKeyErr
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@ -369,6 +442,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
}
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
host := model.ParseName(args[0]).Host
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// re-throw an error with a more user-friendly message
return errFromUnknownKey(err)
}
return err
}
@ -796,24 +883,27 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
// retrieve the OLLAMA_HOST environment variable
ollamaHost, err := api.GetOllamaHost()
if err != nil {
host, port = "127.0.0.1", "11434"
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
host = ip.String()
}
return err
}
if err := initializeKeypair(); err != nil {
return err
}
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
if err != nil {
return err
}
return server.Serve(ln)
err = server.Serve(ln)
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
func initializeKeypair() error {
@ -1034,7 +1124,7 @@ Environment Variables:
RunE: ListHandler,
}
copyCmd := &cobra.Command{
Use: "cp SOURCE TARGET",
Use: "cp SOURCE DESTINATION",
Short: "Copy a model",
Args: cobra.ExactArgs(2),
PreRunE: checkServerHeartbeat,

View file

@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /clear Clear session context")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@ -161,7 +162,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
fmt.Fprintln(os.Stderr, " /set parameter stop <string> <string> ... Set the stop parameters")
fmt.Fprintln(os.Stderr, "")
}
@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/clear"):
opts.Messages = []api.Message{}
fmt.Println("Cleared session context")
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {

View file

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
@ -18,19 +19,23 @@ import (
)
type Params struct {
Architectures []string `json:"architectures"`
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"` // n_embd
HiddenLayers int `json:"num_hidden_layers"` // n_layer
ContextSize int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
AttentionHeads int `json:"num_attention_heads"` // n_head
KeyValHeads int `json:"num_key_value_heads"`
NormEPS float64 `json:"rms_norm_eps"`
BoSTokenID int `json:"bos_token_id"`
EoSTokenID int `json:"eos_token_id"`
HeadDimension int `json:"head_dim"`
PaddingTokenID int `json:"pad_token_id"`
Architectures []string `json:"architectures"`
VocabSize int `json:"vocab_size"`
HiddenSize int `json:"hidden_size"` // n_embd
HiddenLayers int `json:"num_hidden_layers"` // n_layer
ContextSize int `json:"max_position_embeddings"`
IntermediateSize int `json:"intermediate_size"`
AttentionHeads int `json:"num_attention_heads"` // n_head
KeyValHeads int `json:"num_key_value_heads"`
NormEPS float64 `json:"rms_norm_eps"`
BoSTokenID int `json:"bos_token_id"`
EoSTokenID int `json:"eos_token_id"`
HeadDimension int `json:"head_dim"`
PaddingTokenID int `json:"pad_token_id"`
RopeFrequencyBase float64 `json:"rope_theta"`
Experts int `json:"num_local_experts"`
ExpertsUsed int `json:"num_experts_per_tok"`
ByteOrder
}
@ -43,7 +48,7 @@ type ByteOrder interface {
type ModelArch interface {
GetTensors() error
LoadVocab() error
WriteGGUF() (string, error)
WriteGGUF(io.WriteSeeker) error
}
type ModelFormat interface {

View file

@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error {
return nil
}
func (m *GemmaModel) WriteGGUF() (string, error) {
func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{
"general.architecture": "gemma",
"general.name": m.Name,
@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) {
"tokenizer.ggml.add_eos_token": false,
}
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}

View file

@ -5,7 +5,6 @@ import (
"fmt"
"io"
"log/slog"
"os"
"regexp"
"strings"
@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error {
return nil
}
func (m *LlamaModel) WriteGGUF() (string, error) {
func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
@ -159,18 +158,5 @@ func (m *LlamaModel) WriteGGUF() (string, error) {
"tokenizer.ggml.add_eos_token": false,
}
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
return f.Name(), nil
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}

View file

@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
return nil
}
func (m *MistralModel) WriteGGUF() (string, error) {
func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
"tokenizer.ggml.unknown_token_id": uint32(0),
}
f, err := os.CreateTemp("", "ollama-gguf")
if err != nil {
return "", err
}
defer f.Close()
mod := llm.NewGGUFV3(m.Params.ByteOrder)
if err := mod.Encode(f, kv, m.Tensors); err != nil {
return "", err
}
return f.Name(), nil
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}

85
convert/mixtral.go Normal file
View file

@ -0,0 +1,85 @@
package convert
import (
"io"
"regexp"
"github.com/ollama/ollama/llm"
)
type MixtralModel struct {
ModelData
}
func (m *MixtralModel) GetTensors() error {
t, err := m.Format.GetTensors(m.Path, m.Params)
if err != nil {
return err
}
m.Tensors = []llm.Tensor{}
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
re, err := regexp.Compile(pattern)
if err != nil {
return err
}
for _, l := range t {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
m.Tensors = append(m.Tensors, l)
}
return nil
}
func (m *MixtralModel) LoadVocab() error {
v, err := LoadSentencePieceTokens(m.Path, m.Params)
if err != nil {
return err
}
m.Vocab = v
return nil
}
func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
kv := llm.KV{
"general.architecture": "llama",
"general.name": m.Name,
"llama.block_count": uint32(m.Params.HiddenLayers),
"llama.context_length": uint32(m.Params.ContextSize),
"llama.embedding_length": uint32(m.Params.HiddenSize),
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"llama.expert_count": uint32(m.Params.Experts),
"llama.expert_used_count": uint32(m.Params.ExpertsUsed),
"llama.vocab_size": uint32(len(m.Vocab.Tokens)),
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.tokens": m.Vocab.Tokens,
"tokenizer.ggml.scores": m.Vocab.Scores,
"tokenizer.ggml.token_type": m.Vocab.Types,
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
"tokenizer.ggml.unknown_token_id": uint32(0),
"tokenizer.ggml.add_bos_token": true,
"tokenizer.ggml.add_eos_token": false,
}
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
}

View file

@ -53,7 +53,7 @@ func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Ten
var err error
t, offset, err = m.readTensors(f, offset, params)
if err != nil {
slog.Error("%v", err)
slog.Error(err.Error())
return nil, err
}
tensors = append(tensors, t...)
@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
}
slices.Sort(keys)
slog.Info("converting layers")
var tensors []llm.Tensor
@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
return nil, 0, err
}
slog.Debug(fmt.Sprintf("metadata = %#v", data))
var size uint64
var kind uint32
switch len(data.Shape) {
@ -124,7 +122,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
ggufName, err := m.GetLayerName(k)
if err != nil {
slog.Error("%v", err)
slog.Error(err.Error())
return nil, 0, err
}
@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
padding: 8 + jsonSize,
}
tensors = append(tensors, t)
offset += size
tensors = append(tensors, t)
}
slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
slog.Debug(fmt.Sprintf("offset = %d", offset))
return tensors, offset, nil
}
@ -185,15 +185,19 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
}
tMap := map[string]string{
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
"model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
"model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
"model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
"model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
"model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
"model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
"model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
"model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
"model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
"model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
}
v, ok := directMap[n]
@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
Format: m,
},
}, nil
case "MixtralForCausalLM":
return &MixtralModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
case "GemmaForCausalLM":
return &GemmaModel{
ModelData{

View file

@ -74,7 +74,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
ggufName, err := tf.GetLayerName(k.(string))
if err != nil {
slog.Error("%v", err)
slog.Error(err.Error())
return nil, err
}
slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))

View file

@ -17,7 +17,7 @@
### Model names
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
### Durations
@ -66,7 +66,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama3",
"prompt": "Why is the sky blue?"
}'
```
@ -77,7 +77,7 @@ A stream of JSON objects is returned:
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"response": "The",
"done": false
@ -90,16 +90,16 @@ The final response in the stream also includes additional data about the generat
- `load_duration`: time spent in nanoseconds loading the model
- `prompt_eval_count`: number of tokens in the prompt
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
- `eval_count`: number of tokens the response
- `eval_count`: number of tokens in the response
- `eval_duration`: time in nanoseconds spent generating the response
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
- `response`: empty if the response was streamed, if not streamed, this will contain the full response
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` * `10^9`.
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "",
"done": true,
@ -121,7 +121,7 @@ A response can be received in one reply when streaming is off.
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama3",
"prompt": "Why is the sky blue?",
"stream": false
}'
@ -133,7 +133,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"done": true,
@ -155,7 +155,7 @@ If `stream` is set to `false`, the response will be a single JSON object:
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama3",
"prompt": "What color is the sky at different times of the day? Respond using JSON",
"format": "json",
"stream": false
@ -166,7 +166,7 @@ curl http://localhost:11434/api/generate -d '{
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-11-09T21:07:55.186497Z",
"response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n",
"done": true,
@ -289,7 +289,7 @@ If you want to set custom options for the model at runtime rather than in the Mo
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama3",
"prompt": "Why is the sky blue?",
"stream": false,
"options": {
@ -332,7 +332,7 @@ curl http://localhost:11434/api/generate -d '{
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"done": true,
@ -354,7 +354,7 @@ If an empty prompt is provided, the model will be loaded into memory.
```shell
curl http://localhost:11434/api/generate -d '{
"model": "llama2"
"model": "llama3"
}'
```
@ -364,7 +364,7 @@ A single JSON object is returned:
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-12-18T19:52:07.071755Z",
"response": "",
"done": true
@ -407,7 +407,7 @@ Send a chat message with a streaming response.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"model": "llama3",
"messages": [
{
"role": "user",
@ -423,7 +423,7 @@ A stream of JSON objects is returned:
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assistant",
@ -438,7 +438,7 @@ Final response:
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 4883583458,
@ -456,7 +456,7 @@ Final response:
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"model": "llama3",
"messages": [
{
"role": "user",
@ -471,7 +471,7 @@ curl http://localhost:11434/api/chat -d '{
```json
{
"model": "registry.ollama.ai/library/llama2:latest",
"model": "registry.ollama.ai/library/llama3:latest",
"created_at": "2023-12-12T14:13:43.416799Z",
"message": {
"role": "assistant",
@ -495,7 +495,7 @@ Send a chat message with a conversation history. You can use this same approach
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"model": "llama3",
"messages": [
{
"role": "user",
@ -519,7 +519,7 @@ A stream of JSON objects is returned:
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assistant",
@ -533,7 +533,7 @@ Final response:
```json
{
"model": "llama2",
"model": "llama3",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 8113331500,
@ -591,7 +591,7 @@ curl http://localhost:11434/api/chat -d '{
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"model": "llama3",
"messages": [
{
"role": "user",
@ -609,7 +609,7 @@ curl http://localhost:11434/api/chat -d '{
```json
{
"model": "registry.ollama.ai/library/llama2:latest",
"model": "registry.ollama.ai/library/llama3:latest",
"created_at": "2023-12-12T14:13:43.416799Z",
"message": {
"role": "assistant",
@ -651,7 +651,7 @@ Create a new model from a `Modelfile`.
```shell
curl http://localhost:11434/api/create -d '{
"name": "mario",
"modelfile": "FROM llama2\nSYSTEM You are mario from Super Mario Bros."
"modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros."
}'
```
@ -758,7 +758,7 @@ A single JSON object will be returned.
}
},
{
"name": "llama2:latest",
"name": "llama3:latest",
"modified_at": "2023-12-07T09:32:18.757212583-08:00",
"size": 3825819519,
"digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
@ -792,7 +792,7 @@ Show information about a model including details, modelfile, template, parameter
```shell
curl http://localhost:11434/api/show -d '{
"name": "llama2"
"name": "llama3"
}'
```
@ -827,8 +827,8 @@ Copy a model. Creates a model with another name from an existing model.
```shell
curl http://localhost:11434/api/copy -d '{
"source": "llama2",
"destination": "llama2-backup"
"source": "llama3",
"destination": "llama3-backup"
}'
```
@ -854,7 +854,7 @@ Delete a model and its data.
```shell
curl -X DELETE http://localhost:11434/api/delete -d '{
"name": "llama2:13b"
"name": "llama3:13b"
}'
```
@ -882,7 +882,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
```shell
curl http://localhost:11434/api/pull -d '{
"name": "llama2"
"name": "llama3"
}'
```

View file

@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
Then generate dependencies:
@ -142,4 +142,4 @@ In addition to the common Windows development tools described above, install AMD
- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
- [Strawberry Perl](https://strawberryperl.com/)
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).

View file

@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter:
```
curl http://localhost:11434/api/generate -d '{
"model": "llama2",
"model": "llama3",
"prompt": "Why is the sky blue?",
"options": {
"num_ctx": 4096
@ -88,9 +88,9 @@ On windows, Ollama inherits your user and system environment variables.
3. Edit or create New variable(s) for your user account for `OLLAMA_HOST`, `OLLAMA_MODELS`, etc.
4. Click OK/Apply to save
4. Click OK/Apply to save
5. Run `ollama` from a new terminal window
5. Run `ollama` from a new terminal window
## How can I expose Ollama on my network?
@ -140,7 +140,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
- macOS: `~/.ollama/models`
- Linux: `/usr/share/ollama/.ollama/models`
- Windows: `C:\Users\<username>\.ollama\models`
- Windows: `C:\Users\%username%\.ollama\models`
### How do I set them to a different location?
@ -221,10 +221,20 @@ The `keep_alive` parameter can be set to:
For example, to preload a model and leave it in memory use:
```shell
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}'
curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": -1}'
```
To unload the model and free up memory use:
```shell
curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
```
Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
## How do I manage the maximum number of requests the server can queue
If too many requests are sent to the server, it will respond with a 503 error
indicating the server is overloaded. You can adjust how many requests may be
queue by setting `OLLAMA_MAX_QUEUE`

View file

@ -125,7 +125,7 @@ Publishing models is in early alpha. If you'd like to publish your model to shar
1. Create [an account](https://ollama.com/signup)
2. Copy your Ollama public key:
- macOS: `cat ~/.ollama/id_ed25519.pub`
- macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
- Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
- Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
@ -136,6 +136,8 @@ Next, copy your model to your username's namespace:
ollama cp example <your username>/example
```
> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
Then push the model:
```

View file

@ -105,7 +105,7 @@ sudo chmod +x /usr/bin/ollama
To view logs of Ollama running as a startup service, run:
```bash
journalctl -u ollama
journalctl -e -u ollama
```
## Uninstall

View file

@ -10,7 +10,7 @@ A model file is the blueprint to create and share models with Ollama.
- [Examples](#examples)
- [Instructions](#instructions)
- [FROM (Required)](#from-required)
- [Build from llama2](#build-from-llama2)
- [Build from llama3](#build-from-llama3)
- [Build from a bin file](#build-from-a-bin-file)
- [PARAMETER](#parameter)
- [Valid Parameters and Values](#valid-parameters-and-values)
@ -48,7 +48,7 @@ INSTRUCTION arguments
An example of a `Modelfile` creating a mario blueprint:
```modelfile
FROM llama2
FROM llama3
# sets the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1
# sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token
@ -67,33 +67,25 @@ To use this:
More examples are available in the [examples directory](../examples).
### `Modelfile`s in [ollama.com/library][1]
There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
- Option 1: view a details page from a model's tags page:
1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
2. Click on a tag (e.g. https://ollama.com/library/llama2:13b)
3. Scroll down to "Layers"
- Note: if the [`FROM` instruction](#from-required) is not present,
it means the model was created from a local file
- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:
To view the Modelfile of a given model, use the `ollama show --modelfile` command.
```bash
> ollama show --modelfile llama2:13b
> ollama show --modelfile llama3
# Modelfile generated by "ollama show"
# To build a new Modelfile based on this one, replace the FROM line with:
# FROM llama2:13b
# FROM llama3:latest
FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
FROM /root/.ollama/models/blobs/sha256:123abc
TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ end }}{{ .Prompt }} [/INST] """
SYSTEM """"""
PARAMETER stop [INST]
PARAMETER stop [/INST]
PARAMETER stop <<SYS>>
PARAMETER stop <</SYS>>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>"""
PARAMETER stop "<|start_header_id|>"
PARAMETER stop "<|end_header_id|>"
PARAMETER stop "<|eot_id|>"
PARAMETER stop "<|reserved_special_token"
```
## Instructions
@ -106,10 +98,10 @@ The `FROM` instruction defines the base model to use when creating a model.
FROM <model name>:<tag>
```
#### Build from llama2
#### Build from llama3
```modelfile
FROM llama2
FROM llama3
```
A list of available base models:

View file

@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create(
'content': 'Say this is a test',
}
],
model='llama2',
model='llama3',
)
```
@ -43,7 +43,7 @@ const openai = new OpenAI({
const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama2',
model: 'llama3',
})
```
@ -53,7 +53,7 @@ const chatCompletion = await openai.chat.completions.create({
curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama2",
"model": "llama3",
"messages": [
{
"role": "system",
@ -113,7 +113,7 @@ curl http://localhost:11434/v1/chat/completions \
Before using a model, pull it locally `ollama pull`:
```shell
ollama pull llama2
ollama pull llama3
```
### Default model names
@ -121,7 +121,7 @@ ollama pull llama2
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
```
ollama cp llama2 gpt-3.5-turbo
ollama cp llama3 gpt-3.5-turbo
```
Afterwards, this new model name can be specified the `model` field:

View file

@ -15,7 +15,7 @@ import { Ollama } from "langchain/llms/ollama";
const ollama = new Ollama({
baseUrl: "http://localhost:11434",
model: "llama2",
model: "llama3",
});
const answer = await ollama.invoke(`why is the sky blue?`);
@ -23,10 +23,10 @@ const answer = await ollama.invoke(`why is the sky blue?`);
console.log(answer);
```
That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
```bash
npm install cheerio
npm install cheerio
```
```javascript

View file

@ -12,15 +12,17 @@ So let's figure out how we can use **LangChain** with Ollama to ask our question
Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
`pip install langchain`
`pip install langchain_community`
Then we can create a model and ask the question:
```python
from langchain.llms import Ollama
ollama = Ollama(base_url='http://localhost:11434',
model="llama2")
print(ollama("why is the sky blue"))
from langchain_community.llms import Ollama
ollama = Ollama(
base_url='http://localhost:11434',
model="llama3"
)
print(ollama.invoke("why is the sky blue"))
```
Notice that we are defining the model and the base URL for Ollama.

View file

@ -1,38 +1,15 @@
# Running Ollama on NVIDIA Jetson Devices
With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.
NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
mode. This can be verified by using a monitoring tool like jtop.
In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
version of our target model.
Prerequisites:
- curl
- tmux
Here are the steps:
The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
- Stop the Ollama service: `sudo systemctl stop ollama`
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
```
FROM mistral
PARAMETER num_gpu 999
```
- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
- Run the new model: `ollama run mistral-jetson`
If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
- Start an interactive session: `ollama run mistral`
And that's it!
# Running Ollama in Docker
When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).

View file

@ -1,47 +1,61 @@
# Ollama Windows Preview
Welcome to the Ollama Windows preview.
No more WSL required!
Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
After installing Ollama Windows Preview, Ollama will run in the background and
the `ollama` command line is available in `cmd`, `powershell` or your favorite
terminal application. As usual the Ollama [api](./api.md) will be served on
`http://localhost:11434`.
As this is a preview release, you should expect a few bugs here and there. If
you run into a problem you can reach out on
[Discord](https://discord.gg/ollama), or file an
[issue](https://github.com/ollama/ollama/issues).
Logs will often be helpful in dianosing the problem (see
[Troubleshooting](#troubleshooting) below)
## System Requirements
* Windows 10 or newer, Home or Pro
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
## API Access
Here's a quick example showing API access from `powershell`
```powershell
(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
```
## Troubleshooting
While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
a "view logs" menu item to the app, and increses logging for the GUI app and
server.
Ollama on Windows stores files in a few different locations. You can view them in
the explorer window by hitting `<cmd>+R` and type in:
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
- *app.log* contains logs from the GUI application
- *server.log* contains the server logs
- *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
# Ollama Windows Preview
Welcome to the Ollama Windows preview.
No more WSL required!
Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
After installing Ollama Windows Preview, Ollama will run in the background and
the `ollama` command line is available in `cmd`, `powershell` or your favorite
terminal application. As usual the Ollama [api](./api.md) will be served on
`http://localhost:11434`.
As this is a preview release, you should expect a few bugs here and there. If
you run into a problem you can reach out on
[Discord](https://discord.gg/ollama), or file an
[issue](https://github.com/ollama/ollama/issues).
Logs will often be helpful in diagnosing the problem (see
[Troubleshooting](#troubleshooting) below)
## System Requirements
* Windows 10 or newer, Home or Pro
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
## API Access
Here's a quick example showing API access from `powershell`
```powershell
(Invoke-WebRequest -method POST -Body '{"model":"llama3", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
```
## Troubleshooting
While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
a "view logs" menu item to the app, and increses logging for the GUI app and
server.
Ollama on Windows stores files in a few different locations. You can view them in
the explorer window by hitting `<cmd>+R` and type in:
- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
- *app.log* contains logs from the GUI application
- *server.log* contains the server logs
- *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
## Standalone CLI
The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
installer. It installs in your account without requiring Administrator rights.
We update Ollama regularly to support the latest models, and this installer will
help you keep up to date.
If you'd like to install or integrate Ollama as a service, a standalone
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
and GPU library dependencies for Nvidia and AMD. This allows for embedding
Ollama in existing applications, or running it as a system service via `ollama
serve` with tools such as [NSSM](https://nssm.cc/).

View file

@ -2,7 +2,7 @@
When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:
`ollama run llama2 < sourcequestions.txt`
`ollama run llama3 < sourcequestions.txt`
This concept is used in the following example.

1
examples/flyio/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
fly.toml

67
examples/flyio/README.md Normal file
View file

@ -0,0 +1,67 @@
# Deploy Ollama to Fly.io
> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
## Prerequisites
- Ollama: https://ollama.com/download
- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
## Steps
1. Login to Fly.io
```bash
fly auth login
```
1. Create a new Fly app
```bash
fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
```
1. Pull and run `orca-mini:3b`
```bash
OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
```
`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type or attach a GPU for hardware acceleration (see below).
## (Optional) Persistent Volume
By default Fly Machines use ephemeral storage which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
1. Create the Fly Volume
```bash
fly volume create ollama
```
1. Update `fly.toml` and add `[mounts]`
```toml
[mounts]
source = "ollama"
destination = "/mnt/ollama/models"
```
1. Update `fly.toml` and add `[env]`
```toml
[env]
OLLAMA_MODELS = "/mnt/ollama/models"
```
1. Deploy your app
```bash
fly deploy
```
## (Optional) Hardware Acceleration
Fly.io GPU is currently in waitlist. Sign up for the waitlist: https://fly.io/gpu
Once you've been accepted, create the app with the additional flags `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.

View file

@ -35,7 +35,7 @@ func main() {
ctx := context.Background()
req := &api.ChatRequest{
Model: "llama2",
Model: "llama3",
Messages: messages,
}

View file

@ -19,7 +19,7 @@ func main() {
}
defer resp.Body.Close()
responseData, err := io.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)

View file

@ -7,12 +7,24 @@
## Steps
1. Create the Ollama namespace, daemon set, and service
1. Create the Ollama namespace, deployment, and service
```bash
kubectl apply -f cpu.yaml
```
## (Optional) Hardware Acceleration
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin) which is deployed in Kubernetes in form of daemonset. Follow the link for more details.
Once configured, create a GPU enabled Ollama deployment.
```bash
kubectl apply -f gpu.yaml
```
## Test
1. Port forward the Ollama service to connect and use it locally
```bash
@ -23,14 +35,4 @@
```bash
ollama run orca-mini:3b
```
## (Optional) Hardware Acceleration
Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
Once configured, create a GPU enabled Ollama deployment.
```bash
kubectl apply -f gpu.yaml
```
```

View file

@ -40,9 +40,9 @@ while True:
continue
# Prompt
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
@ -51,11 +51,11 @@ while True:
template=template,
)
llm = Ollama(model="llama2:13b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
qa_chain = RetrievalQA.from_chain_type(
llm,
retriever=vectorstore.as_retriever(),
chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
)
result = qa_chain({"query": query})
result = qa_chain({"query": query})

View file

@ -1,12 +1,12 @@
from langchain.llms import Ollama
from langchain.document_loaders import WebBaseLoader
from langchain_community.llms import Ollama
from langchain_community.document_loaders import WebBaseLoader
from langchain.chains.summarize import load_summarize_chain
loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
docs = loader.load()
llm = Ollama(model="llama2")
llm = Ollama(model="llama3")
chain = load_summarize_chain(llm, chain_type="stuff")
result = chain.run(docs)
result = chain.invoke(docs)
print(result)

View file

@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama.
## Running the Example
1. Ensure you have the `llama2` model installed:
1. Ensure you have the `llama3` model installed:
```bash
ollama pull llama2
ollama pull llama3
```
2. Install the Python Requirements.
@ -21,4 +21,3 @@ This example is a basic "hello world" of using LangChain with Ollama.
```bash
python main.py
```

View file

@ -1,6 +1,6 @@
from langchain.llms import Ollama
input = input("What is your question?")
llm = Ollama(model="llama2")
llm = Ollama(model="llama3")
res = llm.predict(input)
print (res)

View file

@ -1,4 +1,4 @@
FROM llama2
FROM llama3
PARAMETER temperature 1
SYSTEM """
You are Mario from super mario bros, acting as an assistant.

View file

@ -2,12 +2,12 @@
# Example character: Mario
This example shows how to create a basic character using Llama2 as the base model.
This example shows how to create a basic character using Llama3 as the base model.
To run this example:
1. Download the Modelfile
2. `ollama pull llama2` to get the base model used in the model file.
2. `ollama pull llama3` to get the base model used in the model file.
3. `ollama create NAME -f ./Modelfile`
4. `ollama run NAME`
@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?"
What the model file looks like:
```
FROM llama2
FROM llama3
PARAMETER temperature 1
SYSTEM """
You are Mario from Super Mario Bros, acting as an assistant.

View file

@ -2,16 +2,16 @@ import requests
import json
import random
model = "llama2"
model = "llama3"
template = {
"firstName": "",
"lastName": "",
"firstName": "",
"lastName": "",
"address": {
"street": "",
"city": "",
"state": "",
"street": "",
"city": "",
"state": "",
"zipCode": ""
},
},
"phoneNumber": ""
}

View file

@ -12,7 +12,7 @@ countries = [
"France",
]
country = random.choice(countries)
model = "llama2"
model = "llama3"
prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."

View file

@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran
## Running the Example
1. Ensure you have the `llama2` model installed:
1. Ensure you have the `llama3` model installed:
```bash
ollama pull llama2
ollama pull llama3
```
2. Install the Python Requirements.

View file

@ -2,7 +2,7 @@ import json
import requests
# NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
model = "llama2" # TODO: update this for whatever model you wish to use
model = "llama3" # TODO: update this for whatever model you wish to use
def chat(messages):

View file

@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam
## Running the Example
1. Ensure you have the `llama2` model installed:
1. Ensure you have the `llama3` model installed:
```bash
ollama pull llama2
ollama pull llama3
```
2. Install the Python Requirements.

View file

@ -4,10 +4,10 @@ This example demonstrates how one would create a set of 'mentors' you can have a
## Usage
1. Add llama2 to have the mentors ask your questions:
1. Add llama3 to have the mentors ask your questions:
```bash
ollama pull llama2
ollama pull llama3
```
2. Install prerequisites:

View file

@ -15,7 +15,7 @@ async function characterGenerator() {
ollama.setModel("stablebeluga2:70b-q4_K_M");
const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `);
const thecontents = `FROM llama2\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
const thecontents = `FROM llama3\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`;
fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => {
if (err) throw err;
@ -23,4 +23,4 @@ async function characterGenerator() {
});
}
characterGenerator();
characterGenerator();

View file

@ -1,6 +1,6 @@
import * as readline from "readline";
const model = "llama2";
const model = "llama3";
type Message = {
role: "assistant" | "user" | "system";
content: string;
@ -74,4 +74,4 @@ async function main() {
}
main();
main();

View file

@ -15,6 +15,7 @@ const (
KibiByte = Byte * 1024
MebiByte = KibiByte * 1024
GibiByte = MebiByte * 1024
)
func HumanBytes(b int64) string {
@ -52,6 +53,8 @@ func HumanBytes(b int64) string {
func HumanBytes2(b uint64) string {
switch {
case b >= GibiByte:
return fmt.Sprintf("%.1f GiB", float64(b)/GibiByte)
case b >= MebiByte:
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
case b >= KibiByte:

View file

@ -13,12 +13,20 @@ const (
func HumanNumber(b uint64) string {
switch {
case b > Billion:
return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion))
case b > Million:
return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million))
case b > Thousand:
return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand))
case b >= Billion:
number := float64(b) / Billion
if number == math.Floor(number) {
return fmt.Sprintf("%.0fB", number) // no decimals if whole number
}
return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
case b >= Million:
number := float64(b) / Million
if number == math.Floor(number) {
return fmt.Sprintf("%.0fM", number) // no decimals if whole number
}
return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
case b >= Thousand:
return fmt.Sprintf("%.0fK", float64(b)/Thousand)
default:
return fmt.Sprintf("%d", b)
}

34
format/format_test.go Normal file
View file

@ -0,0 +1,34 @@
package format
import (
"testing"
)
func TestHumanNumber(t *testing.T) {
type testCase struct {
input uint64
expected string
}
testCases := []testCase{
{0, "0"},
{1000000, "1M"},
{125000000, "125M"},
{500500000, "500.50M"},
{500550000, "500.55M"},
{1000000000, "1B"},
{2800000000, "2.8B"},
{2850000000, "2.9B"},
{1000000000000, "1000B"},
}
for _, tc := range testCases {
t.Run(tc.expected, func(t *testing.T) {
result := HumanNumber(tc.input)
if result != tc.expected {
t.Errorf("Expected %s, got %s", tc.expected, result)
}
})
}
}

View file

@ -7,7 +7,7 @@ import (
"log/slog"
"os"
"path/filepath"
"strconv"
"runtime"
"strings"
)
@ -35,22 +35,66 @@ func GetSupportedGFX(libDir string) ([]string, error) {
return ret, nil
}
func amdSetVisibleDevices(ids []int, skip map[int]interface{}) {
// Set the visible devices if not already set
// TODO - does sort order matter?
devices := []string{}
for i := range ids {
if _, skipped := skip[i]; skipped {
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "rocm" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
continue
}
devices = append(devices, strconv.Itoa(i))
ids = append(ids, info.ID)
}
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
}
func commonAMDValidateLibDir() (string, error) {
// We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// This gives users a more recovery options if versions have subtle problems at runtime
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
val := strings.Join(devices, ",")
err := os.Setenv("HIP_VISIBLE_DEVICES", val)
if err != nil {
slog.Warn(fmt.Sprintf("failed to set env: %s", err))
} else {
slog.Info("Setting HIP_VISIBLE_DEVICES=" + val)
// Scan the LD_LIBRARY_PATH or PATH
pathEnv := "LD_LIBRARY_PATH"
if runtime.GOOS == "windows" {
pathEnv = "PATH"
}
paths := os.Getenv(pathEnv)
for _, path := range filepath.SplitList(paths) {
d, err := filepath.Abs(path)
if err != nil {
continue
}
if rocmLibUsable(d) {
return d, nil
}
}
// Well known location(s)
for _, path := range RocmStandardLocations {
if rocmLibUsable(path) {
return path, nil
}
}
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
}

View file

@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) {
func (hl *HipLib) Release() {
err := windows.FreeLibrary(hl.dll)
if err != nil {
slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err))
slog.Warn("failed to unload amdhip64.dll", "error", err)
}
hl.dll = 0
}
@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int {
return 0
}
if status != hipSuccess {
slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err))
slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err)
}
return count
}

View file

@ -11,6 +11,8 @@ import (
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
// Discovery logic for AMD/ROCm GPUs
@ -23,26 +25,20 @@ const (
// Prefix with the node dir
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
RocmStandardLocation = "/opt/rocm/lib"
// TODO find a better way to detect iGPU instead of minimum memory
IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU
)
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"}
)
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices
// and the user hasn't already set this variable
func AMDGetGPUInfo(resp *GpuInfo) {
// TODO - DRY this out with windows
func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
if !AMDDetected() {
return
return resp
}
skip := map[int]interface{}{}
// Opportunistic logging of driver version to aid in troubleshooting
ver, err := AMDDriverVersion()
@ -50,160 +46,117 @@ func AMDGetGPUInfo(resp *GpuInfo) {
slog.Info("AMD Driver: " + ver)
} else {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err))
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
}
// If the user has specified exactly which GPUs to use, look up their memory
visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES")
if visibleDevices != "" {
ids := []int{}
for _, idStr := range strings.Split(visibleDevices, ",") {
id, err := strconv.Atoi(idStr)
if err != nil {
slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err))
} else {
ids = append(ids, id)
}
}
amdProcMemLookup(resp, nil, ids)
return
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
var visibleDevices []string
hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
switch {
// TODO is this priorty order right?
case hipVD != "":
visibleDevices = strings.Split(hipVD, ",")
case rocrVD != "":
visibleDevices = strings.Split(rocrVD, ",")
// TODO - since we don't yet support UUIDs, consider detecting and reporting here
// all our test systems show GPU-XX indicating UUID is not supported
case gpuDO != "":
visibleDevices = strings.Split(gpuDO, ",")
}
// Gather GFX version information from all detected cards
gfx := AMDGFXVersions()
verStrings := []string{}
for i, v := range gfx {
verStrings = append(verStrings, v.ToGFXString())
if v.Major == 0 {
// Silently skip CPUs
skip[i] = struct{}{}
continue
}
if v.Major < 9 {
// TODO consider this a build-time setting if we can support 8xx family GPUs
slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString()))
skip[i] = struct{}{}
}
}
slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings))
// Abort if all GPUs are skipped
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
// If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
}
updateLibPath(libDir)
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
if gfxOverride == "" {
supported, err := GetSupportedGFX(libDir)
var supported []string
libDir := ""
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
cpuCount := 0
for _, match := range matches {
slog.Debug("evaluating amdgpu node " + match)
fp, err := os.Open(match)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
}
slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported))
for i, v := range gfx {
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
}
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
if len(skip) >= len(gfx) {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
ids := make([]int, len(gfx))
i := 0
for k := range gfx {
ids[i] = k
i++
}
amdProcMemLookup(resp, skip, ids)
if resp.memInfo.DeviceCount == 0 {
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
}
func updateLibPath(libDir string) {
ldPaths := []string{}
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
ldPaths = strings.Split(val, ":")
}
for _, d := range ldPaths {
if d == libDir {
return
}
}
val := strings.Join(append(ldPaths, libDir), ":")
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
os.Setenv("LD_LIBRARY_PATH", val)
}
// Walk the sysfs nodes for the available GPUs and gather information from them
// skipping over any devices in the skip map
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
slog.Debug("discovering VRAM for amdgpu devices")
if len(ids) == 0 {
entries, err := os.ReadDir(AMDNodesSysfsDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err))
return
}
for _, node := range entries {
if !node.IsDir() {
continue
}
id, err := strconv.Atoi(node.Name())
if err != nil {
slog.Warn("malformed amdgpu sysfs node id " + node.Name())
continue
}
ids = append(ids, id)
}
}
slog.Debug(fmt.Sprintf("amdgpu devices %v", ids))
for _, id := range ids {
if _, skipped := skip[id]; skipped {
slog.Debug("failed to open sysfs node", "file", match, "error", err)
continue
}
defer fp.Close()
nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug("failed to parse node ID", "error", err)
continue
}
scanner := bufio.NewScanner(fp)
isCPU := false
var major, minor, patch uint64
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
// Detect CPUs
if len(ver) == 2 && ver[1] == "0" {
slog.Debug("detected CPU " + match)
isCPU = true
break
}
if len(ver) != 2 || len(ver[1]) < 5 {
slog.Warn("malformed "+match, "gfx_target_version", line)
// If this winds up being a CPU, our offsets may be wrong
continue
}
l := len(ver[1])
var err1, err2, err3 error
patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
}
// TODO - any other properties we want to extract and record?
// vendor_id + device_id -> pci lookup for "Name"
// Other metrics that may help us understand relative performance between multiple GPUs
}
if isCPU {
cpuCount++
continue
}
// CPUs are always first in the list
gpuID := nodeID - cpuCount
// Shouldn't happen, but just in case...
if gpuID < 0 {
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
return []GpuInfo{}
}
if int(major) < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID)
continue
}
// Look up the memory for the current node
totalMemory := uint64(0)
usedMemory := uint64(0)
// Adjust for sysfs vs HIP ids
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob)
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
propFiles, err := filepath.Glob(propGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err))
slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
}
// 1 or more memory banks - sum the values of all of them
for _, propFile := range propFiles {
fp, err := os.Open(propFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err))
slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
continue
}
defer fp.Close()
@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
}
}
if totalMemory == 0 {
slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id))
skip[id] = struct{}{}
slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
continue
}
if totalMemory < IGPUMemLimit {
slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024))
skip[id] = struct{}{}
continue
}
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob)
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
usedFiles, err := filepath.Glob(usedGlob)
if err != nil {
slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err))
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
continue
}
for _, usedFile := range usedFiles {
fp, err := os.Open(usedFile)
if err != nil {
slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err))
slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
continue
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err))
slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
continue
}
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
if err != nil {
slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err))
slog.Warn("malformed used memory", "data", string(data), "error", err)
continue
}
usedMemory += used
}
slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024))
slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += (totalMemory - usedMemory)
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
if totalMemory < IGPUMemLimit {
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
continue
}
slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: (totalMemory - usedMemory),
},
ID: fmt.Sprintf("%d", gpuID),
// Name: not exposed in sysfs directly, would require pci device id lookup
Major: int(major),
Minor: int(minor),
Patch: int(patch),
MinimumMemory: rocmMinimumMemory,
}
// If the user wants to filter to a subset of devices, filter out if we aren't a match
if len(visibleDevices) > 0 {
include := false
for _, visible := range visibleDevices {
if visible == gpuInfo.ID {
include = true
break
}
}
if !include {
slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices)
continue
}
}
// Final validation is gfx compatibility - load the library if we haven't already loaded it
// even if the user overrides, we still need to validate the library
if libDir == "" {
libDir, err = AMDValidateLibDir()
if err != nil {
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return []GpuInfo{}
}
}
gpuInfo.DependencyPath = libDir
if gfxOverride == "" {
// Only load supported list once
if len(supported) == 0 {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return []GpuInfo{}
}
slog.Debug("rocm supported GPUs", "types", supported)
}
gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
continue
} else {
slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx)
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
// The GPU has passed all the verification steps and is supported
resp = append(resp, gpuInfo)
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
if len(resp) == 0 {
slog.Info("no compatible amdgpu devices detected")
}
return resp
}
// Quick check for AMD driver so we can skip amdgpu discovery if not present
@ -280,87 +297,24 @@ func AMDDetected() bool {
slog.Debug("amdgpu driver not detected " + sysfsDir)
return false
} else if err != nil {
slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err))
slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err)
return false
}
return true
}
func setupLink(source, target string) error {
if err := os.RemoveAll(target); err != nil {
return fmt.Errorf("failed to remove old rocm directory %s %w", target, err)
}
if err := os.Symlink(source, target); err != nil {
return fmt.Errorf("failed to create link %s => %s %w", source, target, err)
}
slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target))
return nil
}
// Ensure the AMD rocm lib dir is wired up
// Prefer to use host installed ROCm, as long as it meets our minimum requirements
// failing that, tell the user how to download it on their own
func AMDValidateLibDir() (string, error) {
// We rely on the rpath compiled into our library to find rocm
// so we establish a symlink to wherever we find it on the system
// to <payloads>/rocm
payloadsDir, err := PayloadsDir()
if err != nil {
return "", err
}
// If we already have a rocm dependency wired, nothing more to do
rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm"))
if rocmLibUsable(rocmTargetDir) {
return rocmTargetDir, nil
}
// next to the running binary
exe, err := os.Executable()
libDir, err := commonAMDValidateLibDir()
if err == nil {
peerDir := filepath.Dir(exe)
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
peerDir = filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(peerDir) {
slog.Debug("detected ROCM next to ollama executable " + peerDir)
return rocmTargetDir, setupLink(peerDir, rocmTargetDir)
}
return libDir, nil
}
// Well known ollama installer path
installedRocmDir := "/usr/share/ollama/lib/rocm"
if rocmLibUsable(installedRocmDir) {
return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir)
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "lib")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir)
}
}
// Scan the library path for potential matches
ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
if rocmLibUsable(d) {
return rocmTargetDir, setupLink(d, rocmTargetDir)
}
}
// Well known location(s)
if rocmLibUsable("/opt/rocm/lib") {
return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir)
return installedRocmDir, nil
}
// If we still haven't found a usable rocm, the user will have to install it on their own
@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) {
}
return strings.TrimSpace(string(verString)), nil
}
func AMDGFXVersions() map[int]Version {
// The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
res := map[int]Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
i, err := strconv.Atoi(filepath.Base(filepath.Dir(match)))
if err != nil {
slog.Debug(fmt.Sprintf("failed to parse node ID %s", err))
continue
}
if i == 0 {
// Skipping the CPU
continue
}
// Align with HIP IDs (zero is first GPU, not CPU)
i -= 1
scanner := bufio.NewScanner(fp)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
if ver[1] != "0" {
slog.Debug("malformed " + line)
}
res[i] = Version{
Major: 0,
Minor: 0,
Patch: 0,
}
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res[i] = Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
}
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

View file

@ -7,11 +7,13 @@ import (
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
const (
RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob?
// TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true
iGPUName = "AMD Radeon(TM) Graphics"
@ -19,39 +21,36 @@ const (
var (
// Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
)
func AMDGetGPUInfo(resp *GpuInfo) {
func AMDGetGPUInfo() []GpuInfo {
resp := []GpuInfo{}
hl, err := NewHipLib()
if err != nil {
slog.Debug(err.Error())
return
return nil
}
defer hl.Release()
skip := map[int]interface{}{}
ids := []int{}
resp.memInfo.DeviceCount = 0
resp.memInfo.TotalMemory = 0
resp.memInfo.FreeMemory = 0
ver, err := hl.AMDDriverVersion()
if err == nil {
slog.Info("AMD Driver: " + ver)
} else {
// For now this is benign, but we may eventually need to fail compatibility checks
slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err))
slog.Debug("error looking up amd driver version", "error", err)
}
// Note: the HIP library automatically handles HIP_VISIBLE_DEVICES
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount()
if count == 0 {
return
return nil
}
libDir, err := AMDValidateLibDir()
if err != nil {
slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err))
return
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
return nil
}
var supported []string
@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) {
if gfxOverride == "" {
supported, err = GetSupportedGFX(libDir)
if err != nil {
slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err))
return
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
return nil
}
} else {
slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride)
}
slog.Info(fmt.Sprintf("detected %d hip devices", count))
slog.Info("detected hip devices", "count", count)
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
for i := 0; i < count; i++ {
ids = append(ids, i)
err = hl.HipSetDevice(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
slog.Warn("set device", "id", i, "error", err)
continue
}
props, err := hl.HipGetDeviceProperties(i)
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
skip[i] = struct{}{}
slog.Warn("get properties", "id", i, "error", err)
continue
}
n := bytes.IndexByte(props.Name[:], 0)
name := string(props.Name[:n])
slog.Info(fmt.Sprintf("[%d] Name: %s", i, name))
// TODO is UUID actually populated on windows?
// Can luid be used on windows for setting visible devices (and is it actually set?)
n = bytes.IndexByte(props.GcnArchName[:], 0)
gfx := string(props.GcnArchName[:n])
slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx))
slog.Info("hip device", "id", i, "name", name, "gfx", gfx)
var major, minor, patch string
switch len(gfx) {
case 6:
major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:]
case 7:
major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:]
}
//slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0
// TODO Why isn't props.iGPU accurate!?
if strings.EqualFold(name, iGPUName) {
slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i))
skip[i] = struct{}{}
slog.Info("iGPU detected skipping", "id", i)
continue
}
if gfxOverride == "" {
if !slices.Contains[[]string, string](supported, gfx) {
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported))
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
skip[i] = struct{}{}
continue
} else {
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx))
slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx)
}
}
totalMemory, freeMemory, err := hl.HipMemGetInfo()
freeMemory, totalMemory, err := hl.HipMemGetInfo()
if err != nil {
slog.Warn(fmt.Sprintf("[%d] %s", i, err))
slog.Warn("get mem info", "id", i, "error", err)
continue
}
// TODO according to docs, freeMem may lie on windows!
slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory))
slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory))
resp.memInfo.DeviceCount++
resp.memInfo.TotalMemory += totalMemory
resp.memInfo.FreeMemory += freeMemory
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
if totalMemory < IGPUMemLimit {
slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory))
continue
}
// TODO revisit this once ROCm v6 is available on windows.
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
gpuInfo := GpuInfo{
Library: "rocm",
memInfo: memInfo{
TotalMemory: totalMemory,
FreeMemory: freeMemory,
},
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
DependencyPath: libDir,
MinimumMemory: rocmMinimumMemory,
}
if major != "" {
gpuInfo.Major, err = strconv.Atoi(major)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if minor != "" {
gpuInfo.Minor, err = strconv.Atoi(minor)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
}
}
if patch != "" {
// Patch rev is hex; e.g. gfx90a
p, err := strconv.ParseInt(patch, 16, 0)
if err != nil {
slog.Info("failed to parse version", "version", gfx, "error", err)
} else {
gpuInfo.Patch = int(p)
}
}
if gpuInfo.Major < RocmComputeMin {
slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch))
continue
}
resp = append(resp, gpuInfo)
}
if resp.memInfo.DeviceCount > 0 {
resp.Library = "rocm"
}
// Abort if all GPUs are skipped
if len(skip) >= count {
slog.Info("all detected amdgpus are skipped, falling back to CPU")
return
}
if len(skip) > 0 {
amdSetVisibleDevices(ids, skip)
}
UpdatePath(libDir)
return resp
}
func AMDValidateLibDir() (string, error) {
// On windows non-admins typically can't create links
// so instead of trying to rely on rpath and a link in
// $LibDir/rocm, we instead rely on setting PATH to point
// to the location of the ROCm library
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
libDir, err := commonAMDValidateLibDir()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
return libDir, nil
}
// Installer payload (if we're running from some other location)
@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) {
return rocmTargetDir, nil
}
// Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH")
if hipPath != "" {
hipLibDir := filepath.Join(hipPath, "bin")
if rocmLibUsable(hipLibDir) {
slog.Debug("detected ROCM via HIP_PATH=" + hipPath)
return hipLibDir, nil
}
}
// Well known location(s)
if rocmLibUsable(RocmStandardLocation) {
return RocmStandardLocation, nil
}
// Should not happen on windows since we include it in the installer, but stand-alone binary might hit this
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")

View file

@ -12,6 +12,8 @@ import (
"sync"
"syscall"
"time"
"github.com/ollama/ollama/server/envconfig"
)
var (
@ -24,8 +26,16 @@ func PayloadsDir() (string, error) {
defer lock.Unlock()
var err error
if payloadsDir == "" {
runnersDir := envconfig.RunnersDir
if runnersDir != "" {
payloadsDir = runnersDir
return payloadsDir, nil
}
// The remainder only applies on non-windows where we still carry payloads in the main executable
cleanupTmpDirs()
tmpDir := os.Getenv("OLLAMA_TMPDIR")
tmpDir := envconfig.TmpDir
if tmpDir == "" {
tmpDir, err = os.MkdirTemp("", "ollama")
if err != nil {
@ -80,7 +90,7 @@ func cleanupTmpDirs() {
}
err = os.RemoveAll(d)
if err != nil {
slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err))
slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err)
}
}
}
@ -88,7 +98,8 @@ func cleanupTmpDirs() {
func Cleanup() {
lock.Lock()
defer lock.Unlock()
if payloadsDir != "" {
runnersDir := envconfig.RunnersDir
if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" {
// We want to fully clean up the tmpdir parent of the payloads dir
tmpDir := filepath.Clean(filepath.Join(payloadsDir, ".."))
slog.Debug("cleaning up", "dir", tmpDir)
@ -120,7 +131,7 @@ func UpdatePath(dir string) {
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
slog.Info("updating", "PATH", newPath)
os.Setenv("PATH", newPath)
}
// linux and darwin rely on rpath

22
gpu/cuda_common.go Normal file
View file

@ -0,0 +1,22 @@
//go:build linux || windows
package gpu
import (
"log/slog"
"strings"
)
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "cuda" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View file

@ -16,22 +16,23 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/server/envconfig"
)
type handles struct {
nvml *C.nvml_handle_t
cudart *C.cudart_handle_t
deviceCount int
cudart *C.cudart_handle_t
nvcuda *C.nvcuda_handle_t
}
const (
cudaMinimumMemory = 457 * format.MebiByte
rocmMinimumMemory = 457 * format.MebiByte
cudaMinimumMemory = 256 * format.MebiByte
rocmMinimumMemory = 256 * format.MebiByte
)
var gpuMutex sync.Mutex
@ -39,26 +40,10 @@ var gpuMutex sync.Mutex
// With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library
var NvmlLinuxGlobs = []string{
"/usr/local/cuda/lib64/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
"/usr/lib/wsl/lib/libnvidia-ml.so*",
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
"/opt/cuda/lib64/libnvidia-ml.so*",
"/usr/lib*/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
"/usr/local/lib*/libnvidia-ml.so*",
var RocmComputeMin = 9
// TODO: are these stubs ever valid?
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
}
var NvmlWindowsGlobs = []string{
"c:\\Windows\\System32\\nvml.dll",
}
// TODO find a better way to detect iGPU instead of minimum memory
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
var CudartLinuxGlobs = []string{
"/usr/local/cuda/lib64/libcudart.so*",
@ -79,6 +64,22 @@ var CudartWindowsGlobs = []string{
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
}
var NvcudaLinuxGlobs = []string{
"/usr/local/cuda*/targets/*/lib/libcuda.so*",
"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
"/usr/lib/*-linux-gnu/libcuda.so*",
"/usr/lib/wsl/lib/libcuda.so*",
"/usr/lib/wsl/drivers/*/libcuda.so*",
"/opt/cuda/lib*/libcuda.so*",
"/usr/local/cuda/lib*/libcuda.so*",
"/usr/lib*/libcuda.so*",
"/usr/local/lib*/libcuda.so*",
}
var NvcudaWindowsGlobs = []string{
"c:\\windows\\system*\\nvcuda.dll",
}
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")
@ -88,61 +89,62 @@ func initGPUHandles() *handles {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
gpuHandles := &handles{nil, nil}
var nvmlMgmtName string
var nvmlMgmtPatterns []string
gpuHandles := &handles{}
var cudartMgmtName string
var cudartMgmtPatterns []string
var nvcudaMgmtName string
var nvcudaMgmtPatterns []string
tmpDir, _ := PayloadsDir()
switch runtime.GOOS {
case "windows":
nvmlMgmtName = "nvml.dll"
nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs))
copy(nvmlMgmtPatterns, NvmlWindowsGlobs)
cudartMgmtName = "cudart64_*.dll"
localAppData := os.Getenv("LOCALAPPDATA")
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "nvcuda.dll"
nvcudaMgmtPatterns = NvcudaWindowsGlobs
case "linux":
nvmlMgmtName = "libnvidia-ml.so"
nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs))
copy(nvmlMgmtPatterns, NvmlLinuxGlobs)
cudartMgmtName = "libcudart.so*"
if tmpDir != "" {
// TODO - add "payloads" for subprocess
cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
}
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
// Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "libcuda.so*"
nvcudaMgmtPatterns = NvcudaLinuxGlobs
default:
return gpuHandles
}
slog.Info("Detecting GPU type")
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
if len(cudartLibPaths) > 0 {
cudart := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil {
slog.Info("Nvidia GPU detected via cudart")
gpuHandles.cudart = cudart
slog.Info("Detecting GPUs")
nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
if len(nvcudaLibPaths) > 0 {
deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
if nvcuda != nil {
slog.Info("detected GPUs", "count", deviceCount, "library", libPath)
gpuHandles.nvcuda = nvcuda
gpuHandles.deviceCount = deviceCount
return gpuHandles
}
}
// TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files
nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns)
if len(nvmlLibPaths) > 0 {
nvml := LoadNVMLMgmt(nvmlLibPaths)
if nvml != nil {
slog.Info("Nvidia GPU detected via nvidia-ml")
gpuHandles.nvml = nvml
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
if len(cudartLibPaths) > 0 {
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
if cudart != nil {
slog.Info("detected GPUs", "library", libPath, "count", deviceCount)
gpuHandles.cudart = cudart
gpuHandles.deviceCount = deviceCount
return gpuHandles
}
}
return gpuHandles
}
func GetGPUInfo() GpuInfo {
func GetGPUInfo() GpuInfoList {
// TODO - consider exploring lspci (and equivalent on windows) to check for
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
gpuMutex.Lock()
@ -150,12 +152,12 @@ func GetGPUInfo() GpuInfo {
gpuHandles := initGPUHandles()
defer func() {
if gpuHandles.nvml != nil {
C.nvml_release(*gpuHandles.nvml)
}
if gpuHandles.cudart != nil {
C.cudart_release(*gpuHandles.cudart)
}
if gpuHandles.nvcuda != nil {
C.nvcuda_release(*gpuHandles.nvcuda)
}
}()
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
@ -164,73 +166,75 @@ func GetGPUInfo() GpuInfo {
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
}
var memInfo C.mem_info_t
resp := GpuInfo{}
if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.nvml_check_vram(*gpuHandles.nvml, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
// Verify minimum compute capability
var cc C.nvml_compute_capability_t
C.nvml_compute_capability(*gpuHandles.nvml, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
}
} else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.cudart_check_vram(*gpuHandles.cudart, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
// Verify minimum compute capability
var cc C.cudart_compute_capability_t
C.cudart_compute_capability(*gpuHandles.cudart, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda"
resp.MinimumMemory = cudaMinimumMemory
} else {
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
}
} else {
AMDGetGPUInfo(&resp)
if resp.Library != "" {
resp.MinimumMemory = rocmMinimumMemory
return resp
}
}
if resp.Library == "" {
C.cpu_check_ram(&memInfo)
resp.Library = "cpu"
resp.Variant = cpuVariant
}
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
return resp
// On windows we bundle the nvidia library one level above the runner dir
depPath := ""
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
depPath = filepath.Dir(envconfig.RunnersDir)
}
var memInfo C.mem_info_t
resp := []GpuInfo{}
// NVIDIA first
for i := 0; i < gpuHandles.deviceCount; i++ {
// TODO once we support CPU compilation variants of GPU libraries refine this...
if cpuVariant == "" && runtime.GOARCH == "amd64" {
continue
}
gpuInfo := GpuInfo{
Library: "cuda",
}
if gpuHandles.cudart != nil {
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
} else {
C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
}
if memInfo.err != nil {
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
continue
}
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
continue
}
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Major = int(memInfo.major)
gpuInfo.Minor = int(memInfo.minor)
gpuInfo.MinimumMemory = cudaMinimumMemory
gpuInfo.DependencyPath = depPath
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
resp = append(resp, gpuInfo)
}
// Then AMD
resp = append(resp, AMDGetGPUInfo()...)
if len(resp) == 0 {
C.cpu_check_ram(&memInfo)
if memInfo.err != nil {
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
return resp
}
gpuInfo := GpuInfo{
Library: "cpu",
Variant: cpuVariant,
}
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
resp = append(resp, gpuInfo)
}
resp.DeviceCount = uint32(memInfo.count)
resp.FreeMemory = uint64(memInfo.free)
resp.TotalMemory = uint64(memInfo.total)
return resp
}
func getCPUMem() (memInfo, error) {
func GetCPUMem() (memInfo, error) {
var ret memInfo
var info C.mem_info_t
C.cpu_check_ram(&info)
@ -243,29 +247,12 @@ func getCPUMem() (memInfo, error) {
return ret, nil
}
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return gpuInfo.FreeMemory, nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
}
func FindGPULibs(baseLibName string, patterns []string) []string {
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
var ldPaths []string
var patterns []string
gpuLibPaths := []string{}
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
slog.Debug("Searching for GPU library", "name", baseLibName)
switch runtime.GOOS {
case "windows":
@ -283,8 +270,14 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
}
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
}
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
patterns = append(patterns, defaultPatterns...)
slog.Debug("gpu library search", "globs", patterns)
for _, pattern := range patterns {
// Nvidia PhysX known to return bogus results
if strings.Contains(pattern, "PhysX") {
slog.Debug("skipping PhysX cuda library path", "path", pattern)
}
// Ignore glob discovery errors
matches, _ := filepath.Glob(pattern)
for _, match := range matches {
@ -311,28 +304,11 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
}
}
}
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
slog.Debug("discovered GPU libraries", "paths", gpuLibPaths)
return gpuLibPaths
}
func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t {
var resp C.nvml_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range nvmlLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.nvml_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
}
}
return nil
}
func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) {
var resp C.cudart_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range cudartLibPaths {
@ -340,18 +316,54 @@ func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t {
defer C.free(unsafe.Pointer(lib))
C.cudart_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err)))
slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
return int(resp.num_devices), &resp.ch, libPath
}
}
return nil
return 0, nil, ""
}
func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
var resp C.nvcuda_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range nvcudaLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.nvcuda_init(lib, &resp)
if resp.err != nil {
slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
return int(resp.num_devices), &resp.ch, libPath
}
}
return 0, nil, ""
}
func getVerboseState() C.uint16_t {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
if envconfig.Debug {
return C.uint16_t(1)
}
return C.uint16_t(0)
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variable
//
// If different libraries are detected, the first one is what we use
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
if len(l) == 0 {
return "", ""
}
switch l[0].Library {
case "cuda":
return cudaGetVisibleDevicesEnv(l)
case "rocm":
return rocmGetVisibleDevicesEnv(l)
default:
slog.Debug("no filter required for library " + l[0].Library)
return "", ""
}
}

View file

@ -9,52 +9,47 @@ package gpu
*/
import "C"
import (
"fmt"
"log/slog"
"os"
"runtime"
"strconv"
"github.com/ollama/ollama/format"
)
// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (uint64, error) {
userLimit := os.Getenv("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseInt(userLimit, 10, 64)
if err != nil {
return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
}
slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
return uint64(avail), nil
}
const (
metalMinimumMemory = 384 * format.MebiByte
)
func GetGPUInfo() GpuInfoList {
mem, _ := GetCPUMem()
if runtime.GOARCH == "amd64" {
// gpu not supported, this may not be metal
return 0, nil
}
return uint64(C.getRecommendedMaxVRAM()), nil
}
func GetGPUInfo() GpuInfo {
mem, _ := getCPUMem()
if runtime.GOARCH == "amd64" {
return GpuInfo{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
return []GpuInfo{
{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
},
}
}
return GpuInfo{
info := GpuInfo{
Library: "metal",
memInfo: mem,
ID: "0",
}
info.TotalMemory = uint64(C.getRecommendedMaxVRAM())
// TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work)
info.FreeMemory = info.TotalMemory
info.MinimumMemory = metalMinimumMemory
return []GpuInfo{info}
}
func getCPUMem() (memInfo, error) {
func GetCPUMem() (memInfo, error) {
return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: 0,
DeviceCount: 1,
}, nil
}
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
// No-op on darwin
return "", ""
}

View file

@ -38,12 +38,17 @@
extern "C" {
#endif
#define GPU_ID_LEN 64
typedef struct mem_info {
char *err; // If non-nill, caller responsible for freeing
char gpu_id[GPU_ID_LEN];
uint64_t total;
uint64_t free;
unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
char *err; // If non-nill, caller responsible for freeing
// Compute Capability
int major;
int minor;
} mem_info_t;
void cpu_check_ram(mem_info_t *resp);
@ -52,8 +57,8 @@ void cpu_check_ram(mem_info_t *resp);
}
#endif
#include "gpu_info_nvml.h"
#include "gpu_info_cudart.h"
#include "gpu_info_nvcuda.h"
#endif // __GPU_INFO_H__
#endif // __APPLE__

View file

@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) {
MEMORYSTATUSEX info;
info.dwLength = sizeof(info);
if (GlobalMemoryStatusEx(&info) != 0) {
resp->count = 1;
resp->total = info.ullTotalPhys;
resp->free = info.ullAvailPhys;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
} else {
resp->err = LOAD_ERR();
}
@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) {
if (sysinfo(&info) != 0) {
resp->err = strdup(strerror(errno));
} else {
resp->count = 1;
resp->total = info.totalram * info.mem_unit;
resp->free = info.freeram * info.mem_unit;
resp->major = 0;
resp->minor = 0;
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
}
return;
}

View file

@ -6,6 +6,7 @@
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
cudartReturn_t ret;
resp->err = NULL;
resp->num_devices = 0;
const int buflen = 256;
char buf[buflen + 1];
int i;
@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
{"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount},
{"cudaDeviceGetAttribute", (void *)&resp->ch.cudaDeviceGetAttribute},
{"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion},
{"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties},
{NULL, NULL},
};
@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
char *msg = LOAD_ERR();
@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
snprintf(buf, buflen, "cudart init failure: %d", ret);
@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
}
ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);
if (ret != CUDART_SUCCESS) {
LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
}
void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) {
void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
resp->err = NULL;
cudartMemory_t memInfo = {0,0,0};
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle isn't initialized");
return;
}
// cudaGetDeviceCount takes int type, resp-> count is uint
int deviceCount;
ret = (*h.cudaGetDeviceCount)(&deviceCount);
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
cudaDeviceProp_t props;
ret = (*h.cudaGetDeviceProperties)(&props, i);
if (ret != CUDART_SUCCESS) {
LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret);
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
resp->major = 0;
resp->minor = 0;
} else {
resp->count = (unsigned int)deviceCount;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp-> count; i++) {
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
int allNull = 1;
for (int j = 0; j < 16; j++) {
if (props.uuid.bytes[j] != 0) {
allNull = 0;
break;
}
}
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
resp->err = strdup(buf);
return;
if (allNull != 0) {
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
} else {
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
props.uuid.bytes[0],
props.uuid.bytes[1],
props.uuid.bytes[2],
props.uuid.bytes[3],
props.uuid.bytes[4],
props.uuid.bytes[5],
props.uuid.bytes[6],
props.uuid.bytes[7],
props.uuid.bytes[8],
props.uuid.bytes[9],
props.uuid.bytes[10],
props.uuid.bytes[11],
props.uuid.bytes[12],
props.uuid.bytes[13],
props.uuid.bytes[14],
props.uuid.bytes[15]
);
}
resp->major = props.major;
resp->minor = props.minor;
LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
// TODO add other useful properties from props
}
}
void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
int major = 0;
int minor = 0;
cudartReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("cudart handle not initialized");
return;
}
int devices;
ret = (*h.cudaGetDeviceCount)(&devices);
ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "unable to get cudart device count: %d", ret);
snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.cudaSetDevice)(i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "cudart device failed to initialize");
resp->err = strdup(buf);
return;
}
resp->total = memInfo.total;
resp->free = memInfo.free;
ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i);
if (ret != CUDART_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
}
void cudart_release(cudart_handle_t h) {

View file

@ -6,14 +6,20 @@
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum cudartReturn_enum {
CUDART_SUCCESS = 0,
CUDART_UNSUPPORTED = 1,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
CUDART_ERROR_INVALID_VALUE = 1,
CUDART_ERROR_MEMORY_ALLOCATION = 2,
CUDART_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now...
} cudartReturn_t;
typedef enum cudartDeviceAttr_enum {
cudartDevAttrComputeCapabilityMajor = 75,
cudartDevAttrComputeCapabilityMinor = 76,
// TODO - not yet wired up but may be useful for Jetson or other
// integrated GPU scenarios with shared memory
cudaDevAttrIntegrated = 18
} cudartDeviceAttr_t;
typedef void *cudartDevice_t; // Opaque is sufficient
@ -28,6 +34,92 @@ typedef struct cudartDriverVersion {
int minor;
} cudartDriverVersion_t;
typedef struct cudaUUID {
unsigned char bytes[16];
} cudaUUID_t;
typedef struct cudaDeviceProp {
char name[256]; /**< ASCII string identifying device */
cudaUUID_t uuid; /**< 16-byte unique identifier */
char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */
unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */
size_t totalGlobalMem; /**< Global memory available on device in bytes */
size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */
int regsPerBlock; /**< 32-bit registers available per block */
int warpSize; /**< Warp size in threads */
size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */
int maxThreadsPerBlock; /**< Maximum number of threads per block */
int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */
int maxGridSize[3]; /**< Maximum size of each dimension of a grid */
int clockRate; /**< Clock frequency in kilohertz */
size_t totalConstMem; /**< Constant memory available on device in bytes */
int major; /**< Major compute capability */
int minor; /**< Minor compute capability */
size_t textureAlignment; /**< Alignment requirement for textures */
size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */
int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */
int multiProcessorCount; /**< Number of multiprocessors on device */
int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */
int integrated; /**< Device is integrated as opposed to discrete */
int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */
int computeMode; /**< Compute mode (See ::cudaComputeMode) */
int maxTexture1D; /**< Maximum 1D texture size */
int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */
int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. */
int maxTexture2D[2]; /**< Maximum 2D texture dimensions */
int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */
int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */
int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */
int maxTexture3D[3]; /**< Maximum 3D texture dimensions */
int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */
int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */
int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */
int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */
int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */
int maxSurface1D; /**< Maximum 1D surface size */
int maxSurface2D[2]; /**< Maximum 2D surface dimensions */
int maxSurface3D[3]; /**< Maximum 3D surface dimensions */
int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */
int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */
int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */
int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */
size_t surfaceAlignment; /**< Alignment requirements for surfaces */
int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */
int ECCEnabled; /**< Device has ECC support enabled */
int pciBusID; /**< PCI bus ID of the device */
int pciDeviceID; /**< PCI device ID of the device */
int pciDomainID; /**< PCI domain ID of the device */
int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */
int asyncEngineCount; /**< Number of asynchronous engines */
int unifiedAddressing; /**< Device shares a unified address space with the host */
int memoryClockRate; /**< Peak memory clock frequency in kilohertz */
int memoryBusWidth; /**< Global memory bus width in bits */
int l2CacheSize; /**< Size of L2 cache in bytes */
int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */
int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */
int streamPrioritiesSupported; /**< Device supports stream priorities */
int globalL1CacheSupported; /**< Device supports caching globals in L1 */
int localL1CacheSupported; /**< Device supports caching locals in L1 */
size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */
int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */
int managedMemory; /**< Device supports allocating managed memory on this system */
int isMultiGpuBoard; /**< Device is on a multi-GPU board */
int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */
int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */
int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */
int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */
int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */
int computePreemptionSupported; /**< Device supports Compute Preemption */
int canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */
int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */
int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */
size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */
int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */
int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */
int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */
int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */
size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */
} cudaDeviceProp_t;
typedef struct cudart_handle {
void *handle;
uint16_t verbose;
@ -38,23 +130,17 @@ typedef struct cudart_handle {
cudartReturn_t (*cudaGetDeviceCount)(int *);
cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device);
cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion);
cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device);
} cudart_handle_t;
typedef struct cudart_init_resp {
char *err; // If err is non-null handle is invalid
cudart_handle_t ch;
int num_devices;
} cudart_init_resp_t;
typedef struct cudart_compute_capability {
char *err;
int major;
int minor;
} cudart_compute_capability_t;
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
void cudart_release(cudart_handle_t ch);
#endif // __GPU_INFO_CUDART_H__

203
gpu/gpu_info_nvcuda.c Normal file
View file

@ -0,0 +1,203 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include "gpu_info_nvcuda.h"
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
CUresult ret;
resp->err = NULL;
resp->num_devices = 0;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"cuInit", (void *)&resp->ch.cuInit},
{"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion},
{"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount},
{"cuDeviceGet", (void *)&resp->ch.cuDeviceGet},
{"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute},
{"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid},
{"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3},
{"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2},
{"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy},
{NULL, NULL},
};
resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY);
if (!resp->ch.handle) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
nvcuda_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
for (i = 0; l[i].s != NULL; i++) {
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!*l[i].p) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->ch.cuInit)(0);
if (ret != CUDA_SUCCESS) {
LOG(resp->ch.verbose, "cuInit err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return;
}
snprintf(buf, buflen, "nvcuda init failure: %d", ret);
resp->err = strdup(buf);
return;
}
int version = 0;
nvcudaDriverVersion_t driverVersion;
driverVersion.major = 0;
driverVersion.minor = 0;
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.cuDriverGetVersion)(&version);
if (ret != CUDA_SUCCESS) {
LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret);
} else {
driverVersion.major = version / 1000;
driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
}
ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices);
if (ret != CUDA_SUCCESS) {
LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
}
const int buflen = 256;
void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
resp->err = NULL;
nvcudaMemory_t memInfo = {0,0};
CUresult ret;
CUdevice device = -1;
CUcontext ctx = NULL;
char buf[buflen + 1];
CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
if (h.handle == NULL) {
resp->err = strdup("nvcuda handle isn't initialized");
return;
}
ret = (*h.cuDeviceGet)(&device, i);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "nvcuda device failed to initialize");
resp->err = strdup(buf);
return;
}
resp->major = 0;
resp->minor = 0;
int major = 0;
int minor = 0;
ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device);
if (ret != CUDA_SUCCESS) {
LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret);
} else {
ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device);
if (ret != CUDA_SUCCESS) {
LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret);
} else {
resp->minor = minor;
resp->major = major;
}
}
ret = (*h.cuDeviceGetUuid)(&uuid, device);
if (ret != CUDA_SUCCESS) {
LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret);
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i);
} else {
// GPU-d110a105-ac29-1d54-7b49-9c90440f215b
snprintf(&resp->gpu_id[0], GPU_ID_LEN,
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
uuid.bytes[0],
uuid.bytes[1],
uuid.bytes[2],
uuid.bytes[3],
uuid.bytes[4],
uuid.bytes[5],
uuid.bytes[6],
uuid.bytes[7],
uuid.bytes[8],
uuid.bytes[9],
uuid.bytes[10],
uuid.bytes[11],
uuid.bytes[12],
uuid.bytes[13],
uuid.bytes[14],
uuid.bytes[15]
);
}
// To get memory we have to set (and release) a context
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total);
if (ret != CUDA_SUCCESS) {
snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret);
resp->err = strdup(buf);
// Best effort on failure...
(*h.cuCtxDestroy)(ctx);
return;
}
resp->total = memInfo.total;
resp->free = memInfo.free;
LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
ret = (*h.cuCtxDestroy)(ctx);
if (ret != CUDA_SUCCESS) {
LOG(1, "nvcuda failed to release primary device context %d", ret);
}
}
void nvcuda_release(nvcuda_handle_t h) {
LOG(h.verbose, "releasing nvcuda library\n");
UNLOAD_LIBRARY(h.handle);
// TODO and other context release logic?
h.handle = NULL;
}
#endif // __APPLE__

71
gpu/gpu_info_nvcuda.h Normal file
View file

@ -0,0 +1,71 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_NVCUDA_H__
#define __GPU_INFO_NVCUDA_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum cudaError_enum {
CUDA_SUCCESS = 0,
CUDA_ERROR_INVALID_VALUE = 1,
CUDA_ERROR_MEMORY_ALLOCATION = 2,
CUDA_ERROR_NOT_INITIALIZED = 3,
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
// Other values omitted for now...
} CUresult;
typedef enum CUdevice_attribute_enum {
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75,
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76,
// TODO - not yet wired up but may be useful for Jetson or other
// integrated GPU scenarios with shared memory
CU_DEVICE_ATTRIBUTE_INTEGRATED = 18
} CUdevice_attribute;
typedef void *nvcudaDevice_t; // Opaque is sufficient
typedef struct nvcudaMemory_st {
uint64_t total;
uint64_t free;
} nvcudaMemory_t;
typedef struct nvcudaDriverVersion {
int major;
int minor;
} nvcudaDriverVersion_t;
typedef struct CUuuid_st {
unsigned char bytes[16];
} CUuuid;
typedef int CUdevice;
typedef void* CUcontext;
typedef struct nvcuda_handle {
void *handle;
uint16_t verbose;
CUresult (*cuInit)(unsigned int Flags);
CUresult (*cuDriverGetVersion)(int *driverVersion);
CUresult (*cuDeviceGetCount)(int *);
CUresult (*cuDeviceGet)(CUdevice* device, int ordinal);
CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev);
CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2
// Context specific aspects
CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev);
CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total);
CUresult (*cuCtxDestroy)(CUcontext ctx);
} nvcuda_handle_t;
typedef struct nvcuda_init_resp {
char *err; // If err is non-null handle is invalid
nvcuda_handle_t ch;
int num_devices;
} nvcuda_init_resp_t;
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
void nvcuda_release(nvcuda_handle_t ch);
#endif // __GPU_INFO_NVCUDA_H__
#endif // __APPLE__

View file

@ -1,221 +0,0 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h>
#include "gpu_info_nvml.h"
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
nvmlReturn_t ret;
resp->err = NULL;
const int buflen = 256;
char buf[buflen + 1];
int i;
struct lookup {
char *s;
void **p;
} l[] = {
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
{NULL, NULL},
};
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
if (!resp->ch.handle) {
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
nvml_lib_path, msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
resp->ch.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
resp->err = strdup(buf);
return;
}
}
ret = (*resp->ch.nvmlInit_v2)();
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf);
return;
}
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
} else {
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
}
}
void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) {
resp->err = NULL;
nvmlDevice_t device;
nvmlMemory_t memInfo = {0};
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle isn't initialized");
return;
}
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
if (h.verbose) {
nvmlBrandType_t brand = 0;
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != NVML_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
}
}
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
}
}
void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
resp->err = NULL;
resp->major = 0;
resp->minor = 0;
nvmlDevice_t device;
int major = 0;
int minor = 0;
nvmlReturn_t ret;
const int buflen = 256;
char buf[buflen + 1];
int i;
if (h.handle == NULL) {
resp->err = strdup("nvml handle not initialized");
return;
}
unsigned int devices;
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
for (i = 0; i < devices; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
// Report the lowest major.minor we detect as that limits our compatibility
if (resp->major == 0 || resp->major > major ) {
resp->major = major;
resp->minor = minor;
} else if ( resp->major == major && resp->minor > minor ) {
resp->minor = minor;
}
}
}
void nvml_release(nvml_handle_t h) {
LOG(h.verbose, "releasing nvml library\n");
UNLOAD_LIBRARY(h.handle);
h.handle = NULL;
}
#endif // __APPLE__

View file

@ -1,57 +0,0 @@
#ifndef __APPLE__
#ifndef __GPU_INFO_NVML_H__
#define __GPU_INFO_NVML_H__
#include "gpu_info.h"
// Just enough typedef's to dlopen/dlsym for memory information
typedef enum nvmlReturn_enum {
NVML_SUCCESS = 0,
// Other values omitted for now...
} nvmlReturn_t;
typedef void *nvmlDevice_t; // Opaque is sufficient
typedef struct nvmlMemory_st {
unsigned long long total;
unsigned long long free;
unsigned long long used;
} nvmlMemory_t;
typedef enum nvmlBrandType_enum
{
NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;
typedef struct nvml_handle {
void *handle;
uint16_t verbose;
nvmlReturn_t (*nvmlInit_v2)(void);
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
} nvml_handle_t;
typedef struct nvml_init_resp {
char *err; // If err is non-null handle is invalid
nvml_handle_t ch;
} nvml_init_resp_t;
typedef struct nvml_compute_capability {
char *err;
int major;
int minor;
} nvml_compute_capability_t;
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
void nvml_release(nvml_handle_t ch);
#endif // __GPU_INFO_NVML_H__
#endif // __APPLE__

View file

@ -9,23 +9,16 @@ import (
func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo()
assert.Contains(t, "cuda rocm cpu metal", info.Library)
switch runtime.GOOS {
case "darwin":
// TODO - remove this once MacOS returns some size for CPU
return
case "linux", "windows":
assert.Greater(t, info.TotalMemory, uint64(0))
assert.Greater(t, info.FreeMemory, uint64(0))
assert.Greater(t, info.DeviceCount, uint32(0))
default:
return
assert.Greater(t, len(info), 0)
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
if info[0].Library != "cpu" {
assert.Greater(t, info[0].TotalMemory, uint64(0))
assert.Greater(t, info[0].FreeMemory, uint64(0))
}
}
func TestCPUMemInfo(t *testing.T) {
info, err := getCPUMem()
info, err := GetCPUMem()
assert.NoError(t, err)
switch runtime.GOOS {
case "darwin":

View file

@ -3,7 +3,6 @@ package gpu
type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"`
DeviceCount uint32 `json:"device_count,omitempty"`
}
// Beginning of an `ollama info` command
@ -17,11 +16,49 @@ type GpuInfo struct {
// MinimumMemory represents the minimum memory required to use the GPU
MinimumMemory uint64 `json:"-"`
// TODO add other useful attributes about the card here for discovery information
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
DependencyPath string `json:"lib_path,omitempty"`
// GPU information
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
Name string `json:"name"` // user friendly name if available
Major int `json:"major,omitempty"` // Major compatibility version (CC or gfx)
Minor int `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
Patch int `json:"patch,omitempty"` // Patch compatibility only matters on AMD
// TODO other performance capability info to help in scheduling decisions
}
type Version struct {
Major uint
Minor uint
Patch uint
type GpuInfoList []GpuInfo
// Split up the set of gpu info's by Library and variant
func (l GpuInfoList) ByLibrary() []GpuInfoList {
resp := []GpuInfoList{}
libs := []string{}
for _, info := range l {
found := false
requested := info.Library
if info.Variant != "" {
requested += "_" + info.Variant
}
for i, lib := range libs {
if lib == requested {
resp[i] = append(resp[i], info)
found = true
break
}
}
if !found {
libs = append(libs, info.Library)
resp = append(resp, []GpuInfo{info})
}
}
return resp
}
// Sort by Free Space
type ByFreeMemory []GpuInfo
func (a ByFreeMemory) Len() int { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

View file

@ -4,11 +4,14 @@ package integration
import (
"context"
"net/http"
"log/slog"
"os"
"runtime"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestOrcaMiniBlueSky(t *testing.T) {
@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
"seed": 123,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}
func TestUnicodeModelDir(t *testing.T) {
// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
if runtime.GOOS != "windows" {
t.Skip("Unicode test only applicable to windows")
}
// Only works for local testing
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("TestUnicodeModelDir only works for local testing, skipping")
}
modelDir, err := os.MkdirTemp("", "ollama_埃")
require.NoError(t, err)
defer os.RemoveAll(modelDir)
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
oldModelsDir := os.Getenv("OLLAMA_MODELS")
if oldModelsDir == "" {
defer os.Unsetenv("OLLAMA_MODELS")
} else {
defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
}
err = os.Setenv("OLLAMA_MODELS", modelDir)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
}

View file

@ -0,0 +1,225 @@
//go:build integration
package integration
import (
"context"
"log/slog"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestMultiModelConcurrency(t *testing.T) {
var (
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "tinydolphin",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
[]string{"sunlight"},
[]string{"england", "english", "massachusetts", "pilgrims"},
}
)
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, req[i], resp[i])
}(i)
}
wg.Wait()
}
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req, resp := GenerateRequests()
// Get the server running (if applicable) warm the model up with a single initial request
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
var wg sync.WaitGroup
wg.Add(len(req))
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
for j := 0; j < 5; j++ {
slog.Info("Starting", "req", i, "iter", j)
// On slower GPUs it can take a while to process the 4 concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
}
max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err)
const MB = uint64(1024 * 1024)
type model struct {
name string
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
}
smallModels := []model{
{
name: "orca-mini",
size: 2992 * MB,
},
{
name: "phi",
size: 2616 * MB,
},
{
name: "gemma:2b",
size: 2364 * MB,
},
{
name: "stable-code:3b",
size: 2608 * MB,
},
{
name: "starcoder2:3b",
size: 2166 * MB,
},
}
mediumModels := []model{
{
name: "llama2",
size: 5118 * MB,
},
{
name: "mistral",
size: 4620 * MB,
},
{
name: "orca-mini:7b",
size: 5118 * MB,
},
{
name: "dolphin-mistral",
size: 4620 * MB,
},
{
name: "gemma:7b",
size: 5000 * MB,
},
// TODO - uncomment this once #3565 is merged and this is rebased on it
// {
// name: "codellama:7b",
// size: 5118 * MB,
// },
}
// These seem to be too slow to be useful...
// largeModels := []model{
// {
// name: "llama2:13b",
// size: 7400 * MB,
// },
// {
// name: "codellama:13b",
// size: 7400 * MB,
// },
// {
// name: "orca-mini:13b",
// size: 7400 * MB,
// },
// {
// name: "gemma:7b",
// size: 5000 * MB,
// },
// {
// name: "starcoder2:15b",
// size: 9100 * MB,
// },
// }
var chosenModels []model
switch {
case max < 10000*MB:
slog.Info("selecting small models")
chosenModels = smallModels
// case max < 30000*MB:
default:
slog.Info("selecting medium models")
chosenModels = mediumModels
// default:
// slog.Info("selecting large models")
// chosenModels = largModels
}
req, resp := GenerateRequests()
for i := range req {
if i > len(chosenModels) {
break
}
req[i].Model = chosenModels[i].name
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure all the models are pulled before we get started
for _, r := range req {
require.NoError(t, PullIfMissing(ctx, client, r.Model))
}
var wg sync.WaitGroup
consumed := uint64(256 * MB) // Assume some baseline usage
for i := 0; i < len(req); i++ {
// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
if i > 1 && consumed > max {
slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
break
}
consumed += chosenModels[i].size
slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 3; j++ {
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
}
}(i)
}
wg.Wait()
}

View file

@ -4,7 +4,6 @@ package integration
import (
"context"
"net/http"
"testing"
"time"
@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
}

View file

@ -5,7 +5,6 @@ package integration
import (
"context"
"encoding/base64"
"net/http"
"testing"
"time"
@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
},
}
resp := "the ollamas"
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
GenerateTestHelper(ctx, t, req, []string{resp})
}
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb

View file

@ -4,8 +4,6 @@ package integration
import (
"context"
"net/http"
"sync"
"testing"
"time"
@ -45,25 +43,5 @@ var (
func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
GenerateTestHelper(ctx, t, req[0], resp[0])
}
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
}(i)
}
wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency

View file

@ -0,0 +1,117 @@
//go:build integration
package integration
import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestMaxQueue(t *testing.T) {
// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
threadCount := 32
mq := os.Getenv("OLLAMA_MAX_QUEUE")
if mq != "" {
var err error
threadCount, err = strconv.Atoi(mq)
require.NoError(t, err)
} else {
os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
}
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}
resp := []string{"explore", "discover", "ocean"}
// CPU mode takes much longer at the limit with a large queue setting
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// Context for the worker threads so we can shut them down
// embedCtx, embedCancel := context.WithCancel(ctx)
embedCtx := ctx
var genwg sync.WaitGroup
go func() {
genwg.Add(1)
defer genwg.Done()
slog.Info("Starting generate request")
DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
slog.Info("generate completed")
}()
// Give the generate a chance to get started before we start hammering on embed requests
time.Sleep(5 * time.Millisecond)
threadCount += 10 // Add a few extra to ensure we push the queue past its limit
busyCount := 0
resetByPeerCount := 0
canceledCount := 0
succesCount := 0
counterMu := sync.Mutex{}
var embedwg sync.WaitGroup
for i := 0; i < threadCount; i++ {
go func(i int) {
embedwg.Add(1)
defer embedwg.Done()
slog.Info("embed started", "id", i)
embedReq := api.EmbeddingRequest{
Model: req.Model,
Prompt: req.Prompt,
Options: req.Options,
}
// Fresh client for every request
client, _ = GetTestEndpoint()
resp, genErr := client.Embeddings(embedCtx, &embedReq)
counterMu.Lock()
defer counterMu.Unlock()
switch {
case genErr == nil:
succesCount++
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
case errors.Is(genErr, context.Canceled):
canceledCount++
case strings.Contains(genErr.Error(), "busy"):
busyCount++
case strings.Contains(genErr.Error(), "connection reset by peer"):
resetByPeerCount++
default:
require.NoError(t, genErr, "%d request failed", i)
}
slog.Info("embed finished", "id", i)
}(i)
}
genwg.Wait()
slog.Info("generate done, waiting for embeds")
embedwg.Wait()
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
}

View file

@ -5,13 +5,14 @@ package integration
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
@ -23,9 +24,13 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Init() {
lifecycle.InitLogging()
}
func FindPort() string {
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@ -41,7 +46,7 @@ func FindPort() string {
return strconv.Itoa(port)
}
func GetTestEndpoint() (string, string) {
func GetTestEndpoint() (*api.Client, string) {
defaultPort := "11434"
ollamaHost := os.Getenv("OLLAMA_HOST")
@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
port = FindPort()
}
url := fmt.Sprintf("%s:%s", host, port)
slog.Info("server connection", "url", url)
return scheme, url
slog.Info("server connection", "host", host, "port", port)
return api.NewClient(
&url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
}
// TODO make fanicier, grab logs, etc.
var serverMutex sync.Mutex
var serverReady bool
func StartServer(ctx context.Context, ollamaHost string) error {
func startServer(ctx context.Context, ollamaHost string) error {
// Make sure the server has been built
CLIName, err := filepath.Abs("../ollama")
if err != nil {
@ -98,7 +107,7 @@ func StartServer(ctx context.Context, ollamaHost string) error {
if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
os.Setenv("OLLAMA_HOST", ollamaHost)
t.Setenv("OLLAMA_HOST", ollamaHost)
}
slog.Info("starting server", "url", ollamaHost)
@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
return nil
}
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
requestJSON, err := json.Marshal(showReq)
if err != nil {
return err
}
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
if err != nil {
showCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(5*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName),
)
defer cancel()
_, err := client.Show(showCtx, showReq)
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
break
case err != nil:
return err
}
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode == 200 {
default:
slog.Info("model already present", "model", modelName)
return nil
}
slog.Info("model missing", "status", response.StatusCode)
slog.Info("model missing", "model", modelName)
stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
stallTimer := time.NewTimer(stallDuration)
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
if !stallTimer.Reset(stallDuration) {
return fmt.Errorf("stall was detected, aborting status reporting")
}
return nil
}
stream := true
pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
requestJSON, err = json.Marshal(pullReq)
if err != nil {
return err
}
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
if err != nil {
return err
}
slog.Info("pulling", "model", modelName)
var pullError error
response, err = client.Do(req.WithContext(ctx))
if err != nil {
return err
done := make(chan int)
go func() {
pullError = client.Pull(ctx, pullReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
return fmt.Errorf("download stalled")
case <-done:
return pullError
}
defer response.Body.Close()
if response.StatusCode != 200 {
return fmt.Errorf("failed to pull model") // TODO more details perhaps
}
slog.Info("model pulled", "model", modelName)
return nil
}
var serverProcMutex sync.Mutex
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
// TODO maybe stuff in an init routine?
lifecycle.InitLogging()
requestJSON, err := json.Marshal(genReq)
if err != nil {
t.Fatalf("Error serializing request: %v", err)
// Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
// Starts the server if needed
func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
require.NoError(t, startServer(ctx, testEndpoint))
}
defer func() {
return client, testEndpoint, func() {
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err = os.Remove(lifecycle.ServerLogFile)
err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}
}()
scheme, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
assert.NoError(t, StartServer(ctx, testEndpoint))
}
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
if err != nil {
t.Fatalf("Error pulling model: %v", err)
}
// Make the request and get the response
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
// Set the content type for the request
req.Header.Set("Content-Type", "application/json")
// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
t.Fatalf("Error making request: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body))
// Verify the response is valid JSON
var payload api.GenerateResponse
err = json.Unmarshal(body, &payload)
if err != nil {
assert.NoError(t, err, body)
}
// Verify the response contains the expected data
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(payload.Response), resp) {
atLeastOne = true
break
}
}
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
}
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
}
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.GenerateResponse) error {
// fmt.Print(".")
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
genReq.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(ctx, &genReq, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
}
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "why is the color of dirt brown?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of independence day?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the composition of air?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
},
[][]string{
[]string{"sunlight"},
[]string{"soil", "organic", "earth", "black", "tan"},
[]string{"england", "english", "massachusetts", "pilgrims"},
[]string{"fourth", "july", "declaration", "independence"},
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}

View file

@ -1032,7 +1032,7 @@ struct llama_server_context
slot.has_next_token = false;
}
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
{
slot.stopped_eos = true;
slot.has_next_token = false;
@ -1144,12 +1144,15 @@ struct llama_server_context
res.result_json = json
{
{"content", tkn.text_to_send},
{"stop", false},
{"slot_id", slot.id},
{"multimodal", multimodal}
};
if (!llama_token_is_eog(model, tkn.tok)) {
res.result_json["content"] = tkn.text_to_send;
}
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs_output = {};
@ -1183,8 +1186,6 @@ struct llama_server_context
{"model", params.model_alias},
{"tokens_predicted", slot.n_decoded},
{"tokens_evaluated", slot.n_prompt_tokens},
{"generation_settings", get_formated_generation(slot)},
{"prompt", slot.prompt},
{"truncated", slot.truncated},
{"stopped_eos", slot.stopped_eos},
{"stopped_word", slot.stopped_word},
@ -2644,18 +2645,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.int_value = std::atol(sep);
kvo.val_i64 = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.float_value = std::atof(sep);
kvo.val_f64 = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.bool_value = true;
kvo.val_bool = true;
} else if (std::strcmp(sep, "false") == 0) {
kvo.bool_value = false;
kvo.val_bool = false;
} else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
invalid_param = true;

140
llm/filetype.go Normal file
View file

@ -0,0 +1,140 @@
package llm
import "fmt"
type fileType uint32
const (
fileTypeF32 fileType = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ4_2 // unused
fileTypeQ4_3 // unused
fileTypeQ8_0
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
fileTypeIQ2_XXS
fileTypeIQ2_XS
fileTypeQ2_K_S
fileTypeQ3_K_XS
fileTypeIQ3_XXS
fileTypeUnknown
)
func ParseFileType(s string) (fileType, error) {
switch s {
case "F32":
return fileTypeF32, nil
case "F16":
return fileTypeF16, nil
case "Q4_0":
return fileTypeQ4_0, nil
case "Q4_1":
return fileTypeQ4_1, nil
case "Q4_1_F16":
return fileTypeQ4_1_F16, nil
case "Q8_0":
return fileTypeQ8_0, nil
case "Q5_0":
return fileTypeQ5_0, nil
case "Q5_1":
return fileTypeQ5_1, nil
case "Q2_K":
return fileTypeQ2_K, nil
case "Q3_K_S":
return fileTypeQ3_K_S, nil
case "Q3_K_M":
return fileTypeQ3_K_M, nil
case "Q3_K_L":
return fileTypeQ3_K_L, nil
case "Q4_K_S":
return fileTypeQ4_K_S, nil
case "Q4_K_M":
return fileTypeQ4_K_M, nil
case "Q5_K_S":
return fileTypeQ5_K_S, nil
case "Q5_K_M":
return fileTypeQ5_K_M, nil
case "Q6_K":
return fileTypeQ6_K, nil
case "IQ2_XXS":
return fileTypeIQ2_XXS, nil
case "IQ2_XS":
return fileTypeIQ2_XS, nil
case "Q2_K_S":
return fileTypeQ2_K_S, nil
case "Q3_K_XS":
return fileTypeQ3_K_XS, nil
case "IQ3_XXS":
return fileTypeIQ3_XXS, nil
default:
return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
}
}
func (t fileType) String() string {
switch t {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case fileTypeQ3_K_XS:
return "Q3_K_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
default:
return "unknown"
}
}
func (t fileType) Value() uint32 {
return uint32(t)
}

View file

@ -21,7 +21,7 @@ init_vars() {
# TODO - add additional optimization flags...
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}"
fi
case $(uname -s) in
case $(uname -s) in
"Darwin")
LIB_EXT="dylib"
WHOLE_ARCHIVE="-Wl,-force_load"

View file

@ -57,21 +57,21 @@ init_vars
git_module_setup
apply_patches
init_vars
if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
# Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static"
# Enables optimized Dockerfile builds using a blanket skip and targeted overrides
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}_static"
echo "Building static library"
build
fi
init_vars
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/linux/${ARCH}_static"
echo "Building static library"
build
fi
# Users building from source can tune the exact flags we pass to cmake for configuring
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
@ -165,14 +165,22 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
fi
if [ "${ARCH}" == "arm64" ]; then
echo "ARM CPU detected - disabling unsupported AVX instructions"
# ARM-based CPUs such as M1 and Tegra do not support AVX extensions.
#
# CUDA compute < 6.0 lacks proper FP16 support on ARM.
# Disabling has minimal performance effect while maintaining compatibility.
# CUDA compute < 6.0 lacks proper FP16 support on ARM.
# Disabling has minimal performance effect while maintaining compatibility.
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
fi
CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then
echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\""
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU"
else
CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}"
fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
build
@ -217,6 +225,12 @@ if [ -d "${ROCM_PATH}" ]; then
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
# Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp
if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then
echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\""
CMAKE_DEFS="${CMAKE_DEFS} ${OLLAMA_CUSTOM_ROCM_DEFS}"
echo "Building custom ROCM GPU"
fi
BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
build

View file

@ -26,15 +26,25 @@ function amdGPUs {
$GPU_LIST -join ';'
}
function init_vars {
$script:SRC_DIR = $(resolve-path "..\..\")
$script:llamacppDir = "../llama.cpp"
if (!$script:SRC_DIR) {
$script:SRC_DIR = $(resolve-path "..\..\")
}
if (!$script:llamacppDir) {
$script:llamacppDir = "../llama.cpp"
}
if (!$script:cmakeTargets) {
$script:cmakeTargets = @("ollama_llama_server")
}
$script:cmakeDefs = @(
"-DBUILD_SHARED_LIBS=on",
"-DLLAMA_NATIVE=off"
)
$script:cmakeTargets = @("ollama_llama_server")
$script:ARCH = "amd64" # arm not yet supported.
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
md "$script:DIST_BASE" -ea 0 > $null
if ($env:CGO_CFLAGS -contains "-g") {
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
$script:config = "RelWithDebInfo"
@ -55,7 +65,6 @@ function init_vars {
} else {
$script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
}
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
@ -134,21 +143,18 @@ function sign {
}
}
function compress {
if ($script:GZIP -eq $null) {
write-host "gzip not installed, not compressing files"
return
}
write-host "Compressing binaries..."
function install {
write-host "Installing binaries to dist dir ${script:distDir}"
mkdir ${script:distDir} -ErrorAction SilentlyContinue
$binaries = dir "${script:buildDir}/bin/*.exe"
foreach ($file in $binaries) {
& "$script:GZIP" --best -f $file
copy-item -Path $file -Destination ${script:distDir} -Force
}
write-host "Compressing dlls..."
write-host "Installing dlls to dist dir ${script:distDir}"
$dlls = dir "${script:buildDir}/bin/*.dll"
foreach ($file in $dlls) {
& "$script:GZIP" --best -f $file
copy-item -Path $file -Destination ${script:distDir} -Force
}
}
@ -169,123 +175,195 @@ function cleanup {
}
}
init_vars
git_module_setup
apply_patches
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
function build_static() {
if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
# as we need this to be compiled by gcc for golang to be able to link with itx
write-host "Checking for MinGW..."
# error action ensures we exit on failure
get-command gcc
get-command mingw32-make
$oldTargets = $script:cmakeTargets
$script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @(
"-G", "MinGW Makefiles"
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DLLAMA_NATIVE=off",
"-DLLAMA_AVX=off",
"-DLLAMA_AVX2=off",
"-DLLAMA_AVX512=off",
"-DLLAMA_F16C=off",
"-DLLAMA_FMA=off")
$script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library"
build
$script:cmakeTargets = $oldTargets
} else {
write-host "Skipping CPU generation step as requested"
}
}
function build_cpu($gen_arch) {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
$script:distDir="$script:DIST_BASE\cpu"
write-host "Building LCD CPU"
build
sign
install
} else {
write-host "Skipping CPU generation step as requested"
}
}
function build_cpu_avx() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
$script:distDir="$script:DIST_BASE\cpu_avx"
write-host "Building AVX CPU"
build
sign
install
} else {
write-host "Skipping CPU AVX generation step as requested"
}
}
function build_cpu_avx2() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
$script:distDir="$script:DIST_BASE\cpu_avx2"
write-host "Building AVX2 CPU"
build
sign
install
} else {
write-host "Skipping CPU AVX2 generation step as requested"
}
}
function build_cuda() {
if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
# Then build cuda as a dynamically loaded library
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
if ($null -ne $script:CUDA_VERSION) {
$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) {
write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`""
$script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}")
write-host "building custom CUDA GPU"
}
build
sign
install
write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
} else {
write-host "Skipping CUDA generation step"
}
}
function build_rocm() {
if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
if ($null -ne $script:ROCM_VERSION) {
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
$script:distDir="$script:DIST_BASE\rocm$script:ROCM_VARIANT"
$script:cmakeDefs += @(
"-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DLLAMA_HIPBLAS=on",
"-DHIP_PLATFORM=amd",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
"-DAMDGPU_TARGETS=$(amdGPUs)",
"-DGPU_TARGETS=$(amdGPUs)"
)
# Make sure the ROCm binary dir is first in the path
$env:PATH="$env:HIP_PATH\bin;$env:PATH"
# We have to clobber the LIB var from the developer shell for clang to work properly
$env:LIB=""
if ($null -ne $env:OLLAMA_CUSTOM_ROCM_DEFS) {
write-host "OLLAMA_CUSTOM_ROCM_DEFS=`"${env:OLLAMA_CUSTOM_ROCM_DEFS}`""
$script:cmakeDefs += @("${env:OLLAMA_CUSTOM_ROCM_DEFS}")
write-host "building custom ROCM GPU"
}
write-host "Building ROCm"
build
# Ninja doesn't prefix with config name
${script:config}=""
if ($null -ne $script:DUMPBIN) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
}
sign
install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
# amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
} else {
write-host "Skipping ROCm generation step"
}
}
# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
# as we need this to be compiled by gcc for golang to be able to link with itx
write-host "Checking for MinGW..."
# error action ensures we exit on failure
get-command gcc
get-command mingw32-make
$script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @(
"-G", "MinGW Makefiles"
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DLLAMA_NATIVE=off",
"-DLLAMA_AVX=off",
"-DLLAMA_AVX2=off",
"-DLLAMA_AVX512=off",
"-DLLAMA_F16C=off",
"-DLLAMA_FMA=off")
$script:buildDir="../build/windows/${script:ARCH}_static"
write-host "Building static library"
build
if ($($args.count) -eq 0) {
git_module_setup
apply_patches
build_static
if ($script:ARCH -eq "arm64") {
build_cpu("ARM64")
} else { # amd64
build_cpu("x64")
build_cpu_avx
build_cpu_avx2
build_cuda
build_rocm
}
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
write-host "Building LCD CPU"
build
sign
compress
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
write-host "Building AVX CPU"
build
sign
compress
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
write-host "Building AVX2 CPU"
build
sign
compress
cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"
} else {
write-host "Skipping CPU generation step as requested"
}
if ($null -ne $script:CUDA_LIB_DIR) {
# Then build cuda as a dynamically loaded library
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
if ($null -ne $script:CUDA_VERSION) {
$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build
sign
compress
}
if ($null -ne $env:HIP_PATH) {
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
if ($null -ne $script:ROCM_VERSION) {
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
$script:cmakeDefs += @(
"-G", "Ninja",
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DLLAMA_HIPBLAS=on",
"-DHIP_PLATFORM=amd",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
"-DAMDGPU_TARGETS=$(amdGPUs)",
"-DGPU_TARGETS=$(amdGPUs)"
)
# Make sure the ROCm binary dir is first in the path
$env:PATH="$env:HIP_PATH\bin;$env:PATH"
# We have to clobber the LIB var from the developer shell for clang to work properly
$env:LIB=""
write-host "Building ROCm"
build
# Ninja doesn't prefix with config name
${script:config}=""
if ($null -ne $script:DUMPBIN) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
}
sign
compress
}
cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"
for ( $i = 0; $i -lt $args.count; $i++ ) {
write-host "performing $($args[$i])"
& $($args[$i])
}
}

View file

@ -13,82 +13,6 @@ type GGML struct {
model
}
const (
fileTypeF32 uint32 = iota
fileTypeF16
fileTypeQ4_0
fileTypeQ4_1
fileTypeQ4_1_F16
fileTypeQ8_0 uint32 = iota + 2
fileTypeQ5_0
fileTypeQ5_1
fileTypeQ2_K
fileTypeQ3_K_S
fileTypeQ3_K_M
fileTypeQ3_K_L
fileTypeQ4_K_S
fileTypeQ4_K_M
fileTypeQ5_K_S
fileTypeQ5_K_M
fileTypeQ6_K
fileTypeIQ2_XXS
fileTypeIQ2_XS
fileTypeQ2_K_S
fileTypeQ3_K_XS
fileTypeIQ3_XXS
)
func fileType(fileType uint32) string {
switch fileType {
case fileTypeF32:
return "F32"
case fileTypeF16:
return "F16"
case fileTypeQ4_0:
return "Q4_0"
case fileTypeQ4_1:
return "Q4_1"
case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0"
case fileTypeQ5_0:
return "Q5_0"
case fileTypeQ5_1:
return "Q5_1"
case fileTypeQ2_K:
return "Q2_K"
case fileTypeQ3_K_S:
return "Q3_K_S"
case fileTypeQ3_K_M:
return "Q3_K_M"
case fileTypeQ3_K_L:
return "Q3_K_L"
case fileTypeQ4_K_S:
return "Q4_K_S"
case fileTypeQ4_K_M:
return "Q4_K_M"
case fileTypeQ5_K_S:
return "Q5_K_S"
case fileTypeQ5_K_M:
return "Q5_K_M"
case fileTypeQ6_K:
return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S:
return "Q2_K_S"
case fileTypeQ3_K_XS:
return "Q3_K_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
default:
return "unknown"
}
}
type model interface {
KV() KV
Tensors() Tensors
@ -121,12 +45,12 @@ func (kv KV) ParameterCount() uint64 {
return kv.u64("general.parameter_count")
}
func (kv KV) FileType() string {
func (kv KV) FileType() fileType {
if u64 := kv.u64("general.file_type"); u64 > 0 {
return fileType(uint32(u64))
}
return "unknown"
return fileTypeUnknown
}
func (kv KV) BlockCount() uint64 {
@ -286,6 +210,23 @@ const (
var ErrUnsupportedFormat = errors.New("unsupported model format")
func DetectGGMLType(b []byte) string {
switch binary.LittleEndian.Uint32(b[:4]) {
case FILE_MAGIC_GGML:
return "ggml"
case FILE_MAGIC_GGMF:
return "ggmf"
case FILE_MAGIC_GGJT:
return "ggjt"
case FILE_MAGIC_GGLA:
return "ggla"
case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE:
return "gguf"
default:
return ""
}
}
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
@ -343,7 +284,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b
ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max(
3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV),
4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch),
)
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
// mixtral 8x7b
ffnGateWeight1 := ffnGateWeight.Shape[1]
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
partialOffload = max(

View file

@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
llm.kv[k] = v
}
slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))
// decode tensors
for i := 0; uint64(i) < llm.numTensor(); i++ {
name, err := readGGUFString(llm, rs)
@ -465,11 +463,13 @@ var ggufKVOrder = map[string][]string{
"llama.embedding_length",
"llama.block_count",
"llama.feed_forward_length",
"llama.rope.dimension_count",
"llama.attention.head_count",
"llama.attention.head_count_kv",
"llama.attention.layer_norm_rms_epsilon",
"llama.rope.freq_base",
"llama.rope.dimension_count",
"llama.expert_count",
"llama.expert_used_count",
"gemma.context_length",
"gemma.embedding_length",
"gemma.block_count",
@ -577,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
}
default:
return fmt.Errorf("improper type for '%s'", k)
}
if err != nil {
return err
@ -598,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err
}
dims := 1
if tensor.Shape[1] > 0 {
dims = 2
dims := 0
for cnt := 0; cnt < len(tensor.Shape); cnt++ {
if tensor.Shape[cnt] > 0 {
dims++
}
}
if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {

@ -1 +1 @@
Subproject commit 7593639ce335e8d7f89aa9a54d616951f273af60
Subproject commit 952d03dbead16e4dbdd1d3458486340673cc2465

View file

@ -4,6 +4,7 @@ package llm
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h>
@ -19,7 +20,7 @@ func SystemInfo() string {
return C.GoString(C.llama_print_system_info())
}
func Quantize(infile, outfile, filetype string) error {
func Quantize(infile, outfile string, ftype fileType) error {
cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile))
@ -28,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error {
params := C.llama_model_quantize_default_params()
params.nthread = -1
params.ftype = ftype.Value()
switch filetype {
case "F32":
params.ftype = fileTypeF32
case "F16":
params.ftype = fileTypeF16
case "Q4_0":
params.ftype = fileTypeQ4_0
case "Q4_1":
params.ftype = fileTypeQ4_1
case "Q4_1_F16":
params.ftype = fileTypeQ4_1_F16
case "Q8_0":
params.ftype = fileTypeQ8_0
case "Q5_0":
params.ftype = fileTypeQ5_0
case "Q5_1":
params.ftype = fileTypeQ5_1
case "Q2_K":
params.ftype = fileTypeQ2_K
case "Q3_K_S":
params.ftype = fileTypeQ3_K_S
case "Q3_K_M":
params.ftype = fileTypeQ3_K_M
case "Q3_K_L":
params.ftype = fileTypeQ3_K_L
case "Q4_K_S":
params.ftype = fileTypeQ4_K_S
case "Q4_K_M":
params.ftype = fileTypeQ4_K_M
case "Q5_K_S":
params.ftype = fileTypeQ5_K_S
case "Q5_K_M":
params.ftype = fileTypeQ5_K_M
case "Q6_K":
params.ftype = fileTypeQ6_K
case "IQ2_XXS":
params.ftype = fileTypeIQ2_XXS
case "IQ2_XS":
params.ftype = fileTypeIQ2_XS
case "Q2_K_S":
params.ftype = fileTypeQ2_K_S
case "Q3_K_XS":
params.ftype = fileTypeQ3_K_XS
case "IQ3_XXS":
params.ftype = fileTypeIQ3_XXS
default:
return fmt.Errorf("unknown filetype: %s", filetype)
}
if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
return fmt.Errorf("llama_model_quantize: %d", retval)
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return fmt.Errorf("llama_model_quantize: %d", rc)
}
return nil

View file

@ -2,5 +2,5 @@ package llm
import "embed"
//go:embed build/windows/*/*/bin/*
// unused on windows
var libEmbed embed.FS

185
llm/memory.go Normal file
View file

@ -0,0 +1,185 @@
package llm
import (
"fmt"
"log/slog"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig"
)
// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
var estimatedVRAM uint64
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
// Split up the GPUs by type and try them
for _, gpus := range allGpus.ByLibrary() {
var layerCount int
layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
return true, estimatedVRAM
}
} else {
if layerCount > 0 && layerCount >= opts.NumGPU {
return true, estimatedVRAM
}
}
}
return false, estimatedVRAM
}
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
var memoryAvailable uint64
for _, info := range gpus {
memoryAvailable += info.FreeMemory
}
if envconfig.MaxVRAM > 0 {
memoryAvailable = envconfig.MaxVRAM
}
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
memoryMinimum := gpus[0].MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(len(gpus))
graphPartialOffload *= uint64(len(gpus))
// on metal there's no partial offload overhead
if gpus[0].Library == "metal" {
graphPartialOffload = graphFullOffload
}
layers := ggml.Tensors().Layers()
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size()
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size()
var memoryLayerOutput uint64
if layer, ok := layers["output_norm"]; ok {
memoryLayerOutput += layer.size()
}
if layer, ok := layers["output"]; ok {
memoryLayerOutput += layer.size()
} else if layer, ok := layers["token_embd"]; ok {
memoryLayerOutput += layer.size()
}
if gpus[0].Library == "metal" && opts.UseMMap {
// memory is preallocated for output tensors
memoryRequiredTotal += memoryLayerOutput
memoryRequiredPartial += memoryLayerOutput
}
var layerCount int
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
if gpus[0].Library != "metal" || !opts.UseMMap {
// memory was not preallocated for output tensors
memoryRequiredTotal += memoryLayerOutput
}
if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
if gpus[0].Library == "cpu" {
return 0, 0, memoryRequiredTotal
}
if memoryRequiredPartial > memoryAvailable {
slog.Debug("insufficient VRAM to load any model layers")
return 0, 0, memoryRequiredTotal
}
return layerCount, memoryRequiredPartial, memoryRequiredTotal
}

View file

@ -0,0 +1,12 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index e431c7f7..f077e688 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -3,6 +3,7 @@
// I'll gradually clean and extend it
// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
#include "clip.h"
+#include "common.h"
#include "log.h"
#include "ggml.h"
#include "ggml-alloc.h"

45
llm/patches/04-metal.diff Normal file
View file

@ -0,0 +1,45 @@
diff --git a/ggml-metal.m b/ggml-metal.m
index 0207b787..b5e9884b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
// to the matrix-vector kernel
int ne11_mm_min = 1;
-#if 0
// the numbers below are measured on M2 Ultra for 7B and 13B models
// these numbers do not translate to other devices or model sizes
// TODO: need to find a better approach
- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
- switch (src0t) {
- case GGML_TYPE_F16: ne11_mm_min = 2; break;
- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
- case GGML_TYPE_Q5_0: // not tested yet
- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
- default: ne11_mm_min = 1; break;
- }
+ switch (src0t) {
+ case GGML_TYPE_F16: ne11_mm_min = 2; break;
+ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+ case GGML_TYPE_Q5_0: // not tested yet
+ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+ default: ne11_mm_min = 1; break;
}
-#endif
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel

View file

@ -0,0 +1,24 @@
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index e3c9bcd4..b43f892d 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -573,14 +573,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
struct ggml_tensor * embeddings = inp;
if (ctx->has_class_embedding) {
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+ }
+ ggml_set_name(embeddings, "embeddings");
+ ggml_set_input(embeddings);
+
+ if (ctx->has_class_embedding) {
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
embeddings = ggml_acc(ctx0, embeddings, inp,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
}
- ggml_set_name(embeddings, "embeddings");
- ggml_set_input(embeddings);
-
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
ggml_set_name(positions, "positions");

View file

@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"golang.org/x/exp/slices"
@ -17,7 +18,7 @@ import (
"github.com/ollama/ollama/gpu"
)
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
func Init() error {
payloadsDir, err := gpu.PayloadsDir()
@ -25,13 +26,15 @@ func Init() error {
return err
}
slog.Info("extracting embedded files", "dir", payloadsDir)
binGlob := "build/*/*/*/bin/*"
if runtime.GOOS != "windows" {
slog.Info("extracting embedded files", "dir", payloadsDir)
binGlob := "build/*/*/*/bin/*"
// extract server libraries
err = extractFiles(payloadsDir, binGlob)
if err != nil {
return fmt.Errorf("extract binaries: %v", err)
// extract server libraries
err = extractFiles(payloadsDir, binGlob)
if err != nil {
return fmt.Errorf("extract binaries: %v", err)
}
}
var variants []string
@ -138,6 +141,23 @@ func serversForGpu(info gpu.GpuInfo) []string {
return servers
}
// Return the optimal server for this CPU architecture
func serverForCpu() string {
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
return "metal"
}
variant := gpu.GetCPUVariant()
availableServers := availableServers()
if variant != "" {
for cmp := range availableServers {
if cmp == "cpu_"+variant {
return cmp
}
}
}
return "cpu"
}
// extract extracts the embedded files to the target directory
func extractFiles(targetDir string, glob string) error {
files, err := fs.Glob(libEmbed, glob)

View file

@ -21,21 +21,47 @@ import (
"strings"
"time"
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/server/envconfig"
)
// LlamaServer is an instance of the llama.cpp server
type LlamaServer struct {
type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
EstimatedVRAM() uint64
}
// llmServer is an instance of the llama.cpp server
type llmServer struct {
port int
cmd *exec.Cmd
done chan error // Channel to signal when the process exits
status *StatusWriter
options api.Options
// TODO - this should be broken down by GPU
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
estimatedTotal uint64 // Total size of model
totalLayers uint64
gpuCount int
sem *semaphore.Weighted
}
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
func LoadModel(model string) (*GGML, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
}
f, err := os.Open(model)
if err != nil {
return nil, err
@ -43,144 +69,69 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
defer f.Close()
ggml, _, err := DecodeGGML(f)
if err != nil {
return nil, err
}
return ggml, err
}
// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
var err error
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
memoryAvailable, _ := gpu.CheckVRAM()
info := gpu.GetGPUInfo()
cpuRunner := ""
var estimatedVRAM uint64
var estimatedTotal uint64
var systemMemory uint64
gpuCount := len(gpus)
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
memoryMinimum := info.MinimumMemory
for _, projector := range projectors {
memoryMinimum += projectorMemoryRequirements(projector)
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
cpuRunner = serverForCpu()
gpuCount = 0
} else {
if gpus[0].Library == "metal" {
memInfo, err := gpu.GetCPUMem()
if err != nil {
slog.Error("failed to lookup system memory", "error", err)
} else {
systemMemory = memInfo.TotalMemory
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
}
}
var layers int
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
if graphPartialOffload == 0 {
graphPartialOffload = ggml.KV().GQA() * kv / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
}
graphFullOffload *= uint64(info.DeviceCount)
graphPartialOffload *= uint64(info.DeviceCount)
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
memoryRequiredTotal := memoryMinimum + graphFullOffload
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
memoryRequiredPartial := memoryMinimum + graphPartialOffload
if info.Library != "metal" {
if memoryRequiredPartial > memoryAvailable {
info.Library = "cpu"
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
opts.NumGPU = 0
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
opts.NumGPU = layers
}
}
var layerCount int
layers := ggml.Tensors().Layers()
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size()
// KV is proportional to the number of layers
memoryLayer += kv / ggml.KV().BlockCount()
memoryRequiredTotal += memoryLayer
if memoryAvailable > memoryRequiredPartial+memoryLayer {
memoryRequiredPartial += memoryLayer
layerCount++
}
}
var memoryLayerOutput uint64
for k, v := range layers {
if !strings.HasPrefix(k, "blk.") {
memoryLayerOutput += v.size()
}
}
memoryRequiredTotal += memoryLayerOutput
if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
// disable partial offloading when model is greater than total system memory
opts.NumGPU = 0
} else if memoryAvailable > memoryRequiredTotal {
layerCount = int(ggml.KV().BlockCount()) + 1
memoryRequiredPartial = memoryRequiredTotal
}
if opts.NumGPU < 0 {
opts.NumGPU = layerCount
}
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
slog.Info(
"offload to gpu",
slog.Group(
"layers",
// actual number of layers offloaded
"real", opts.NumGPU,
// estimated number of layers that can be offloaded
"estimate", layerCount,
),
slog.Group(
"memory",
// memory available for offloading
"available", format.HumanBytes2(memoryAvailable),
slog.Group(
"required",
// memory required for full offloading
"full", format.HumanBytes2(memoryRequiredTotal),
// memory required to offload layers.estimate layers
"partial", format.HumanBytes2(memoryRequiredPartial),
// memory of KV cache
"kv", format.HumanBytes2(kv),
),
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(memoryWeights),
// memory of repeating layers
"repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput),
// memory of non-repeating layers
"nonrepeating", format.HumanBytes2(memoryLayerOutput),
),
slog.Group(
"graph",
// memory of graph when fully offloaded
"full", format.HumanBytes2(graphFullOffload),
// memory of graph when not fully offloaded
"partial", format.HumanBytes2(graphPartialOffload),
),
),
)
// Loop through potential servers
finalErr := fmt.Errorf("no suitable llama servers found")
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
availableServers := availableServers()
servers := serversForGpu(info)
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
var servers []string
if cpuRunner != "" {
servers = []string{cpuRunner}
} else {
servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant
}
demandLib := envconfig.LLMLibrary
if demandLib != "" {
serverPath := availableServers[demandLib]
if serverPath == "" {
@ -188,11 +139,15 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
} else {
slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
servers = []string{demandLib}
if strings.HasPrefix(demandLib, "cpu") {
// Omit the GPU flag to silence the warning
opts.NumGPU = -1
}
}
}
if len(servers) == 0 {
return nil, fmt.Errorf("no servers found for %v", info)
return nil, fmt.Errorf("no servers found for %v", gpus)
}
params := []string{
@ -201,7 +156,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
"--embedding",
}
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
if envconfig.Debug {
params = append(params, "--log-format", "json")
} else {
params = append(params, "--log-disable")
@ -211,7 +166,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
}
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
if envconfig.Debug {
params = append(params, "--verbose")
}
@ -249,10 +204,30 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--numa")
}
// Loop through potential servers
var finalErr error
numParallel := envconfig.NumParallel
// TODO (jmorganca): multimodal models don't support parallel yet
// see https://github.com/ollama/ollama/issues/4165
if len(projectors) > 0 {
numParallel = 1
slog.Warn("multimodal models don't support parallel requests yet")
}
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := 0; i < len(servers); i++ {
dir := availableServers[servers[i]]
if dir == "" {
// Shouldn't happen
finalErr = fmt.Errorf("[%d] server %s not listed in available servers %v", i, servers[i], availableServers)
slog.Error("sever list inconsistent", "error", finalErr)
continue
}
if strings.HasPrefix(servers[i], "cpu") {
// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
gpuCount = 0
}
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
port := 0
@ -273,12 +248,21 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
if runtime.GOOS == "windows" {
pathEnv = "PATH"
}
// append the server directory to LD_LIBRARY_PATH/PATH
// prepend the server directory to LD_LIBRARY_PATH/PATH
libraryPaths := []string{dir}
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
// Append our runner directory to the path
// This will favor system libraries over our bundled library dependencies
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
}
// Note: we always put the dependency path first
// since this was the exact version we verified for AMD GPUs
// and we favor what the user had in their path
if gpus[0].DependencyPath != "" {
// TODO refine for multi-gpu support
libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...)
}
server := filepath.Join(dir, "ollama_llama_server")
@ -286,21 +270,66 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
server = server + ".exe"
}
s := &LlamaServer{
port: port,
cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
// Detect tmp cleaners wiping out the file
_, err := os.Stat(server)
if errors.Is(err, os.ErrNotExist) {
slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err)
err = Init()
if err != nil {
slog.Warn("failed to reinitialize payloads", "error", err)
return nil, err
}
}
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
slog.Debug(libEnv)
s.cmd.Env = append(os.Environ(), libEnv)
s := &llmServer{
port: port,
cmd: exec.Command(server, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
estimatedVRAM: estimatedVRAM,
estimatedTotal: estimatedTotal,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: ggml.KV().BlockCount() + 1,
gpuCount: gpuCount,
}
s.cmd.Env = os.Environ()
s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path and visible devices variable with our adjusted version
pathNeeded := true
devicesNeeded := visibleDevicesEnv != ""
for i := range s.cmd.Env {
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) {
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
devicesNeeded = false
}
}
if pathNeeded {
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
}
if devicesNeeded {
s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal)
}
slog.Info("starting llama server", "cmd", s.cmd.String())
// Log at debug as the environment is inherited and might contain sensitive information
slog.Debug("subprocess", "environment", s.cmd.Env)
if err = s.cmd.Start(); err != nil {
// Detect permission denied and augment them essage about noexec
if errors.Is(err, os.ErrPermission) {
finalErr = fmt.Errorf("unable to start server %w. %s may have noexec set. Set OLLAMA_TMPDIR for server to a writable executable directory", err, dir)
continue
}
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
@ -310,12 +339,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
continue
}
// reap subprocess when it exits
go func() {
// Exit status managed via getServerStatus
_ = s.cmd.Wait()
}()
return s, nil
}
@ -347,12 +370,27 @@ type ServerStatus int
const ( // iota is reset to 0
ServerStatusReady ServerStatus = iota
ServerStatusNoSlotsAvaialble
ServerStatusNoSlotsAvailable
ServerStatusLoadingModel
ServerStatusNotResponding
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "llm server ready"
case ServerStatusNoSlotsAvailable:
return "llm busy - no slots available"
case ServerStatusLoadingModel:
return "llm server loading model"
case ServerStatusNotResponding:
return "llm server not responding"
default:
return "llm server error"
}
}
type ServerStatusResp struct {
Status string `json:"status"`
SlotsIdle int `json:"slots_idle"`
@ -360,13 +398,17 @@ type ServerStatusResp struct {
Error string `json:"error"`
}
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
// Fail fast if its exited
if s.cmd.ProcessState != nil {
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
if s.cmd.ProcessState.ExitCode() == -1 {
// Most likely a signal killed it, log some more details to try to help troubleshoot
slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String())
}
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
}
@ -399,7 +441,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
case "ok":
return ServerStatusReady, nil
case "no slot available":
return ServerStatusNoSlotsAvaialble, nil
return ServerStatusNoSlotsAvailable, nil
case "loading model":
return ServerStatusLoadingModel, nil
default:
@ -407,7 +449,30 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error)
}
}
func (s *LlamaServer) Ping(ctx context.Context) error {
// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received
func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) {
var retries int
for {
status, err := s.getServerStatus(ctx)
if err != nil {
return status, err
}
if status == ServerStatusNoSlotsAvailable {
if retries >= 10 {
return status, fmt.Errorf("no slots available after %d retries", retries)
}
time.Sleep(5 * time.Millisecond)
retries++
continue
}
return status, nil
}
}
func (s *llmServer) Ping(ctx context.Context) error {
_, err := s.getServerStatus(ctx)
if err != nil {
slog.Debug("server unhealthy", "error", err)
@ -416,13 +481,25 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
return nil
}
func (s *LlamaServer) WaitUntilRunning() error {
func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
start := time.Now()
expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
slog.Info("waiting for llama runner to start responding")
for {
select {
case <-ctx.Done():
slog.Info("context expired before server started")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case err := <-s.done:
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
default:
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
status, err := s.getServerStatus(ctx)
@ -487,7 +564,6 @@ ws ::= ([ \t\n] ws)?
`
const maxBufferSize = 512 * format.KiloByte
const maxRetries = 3
type ImageData struct {
Data []byte `json:"data"`
@ -524,7 +600,19 @@ type CompletionResponse struct {
EvalDuration time.Duration
}
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return err
}
defer s.sem.Release(1)
// only allow maximum 10 "context shifts" to avoid infinite generation
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict)
}
request := map[string]any{
"prompt": req.Prompt,
"stream": true,
@ -551,11 +639,11 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
}
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
status, err := s.getServerStatusRetry(ctx)
if err != nil {
return err
} else if status != ServerStatusReady {
return fmt.Errorf("unexpected server status: %d", status)
return fmt.Errorf("unexpected server status: %s", status.ToString())
}
if req.Format == "json" {
@ -565,133 +653,113 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
}
}
retryDelay := 100 * time.Microsecond
for retries := 0; retries < maxRetries; retries++ {
if retries > 0 {
time.Sleep(retryDelay) // wait before retrying
retryDelay *= 2 // exponential backoff
}
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
}
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil {
return fmt.Errorf("error creating POST request: %v", err)
}
serverReq.Header.Set("Content-Type", "application/json")
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
res, err := http.DefaultClient.Do(serverReq)
if err != nil {
return fmt.Errorf("POST predict: %v", err)
}
defer res.Body.Close()
if res.StatusCode >= 400 {
bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("error creating POST request: %v", err)
return fmt.Errorf("failed reading llm error response: %w", err)
}
req.Header.Set("Content-Type", "application/json")
log.Printf("llm predict error: %s", bodyBytes)
return fmt.Errorf("%s", bodyBytes)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("POST predict: %v", err)
}
defer resp.Body.Close()
scanner := bufio.NewScanner(res.Body)
buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize)
if resp.StatusCode >= 400 {
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed reading llm error response: %w", err)
// keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string
var tokenRepeat int
for scanner.Scan() {
select {
case <-ctx.Done():
// This handles the request cancellation
return ctx.Err()
default:
line := scanner.Bytes()
if len(line) == 0 {
continue
}
log.Printf("llm predict error: %s", bodyBytes)
return fmt.Errorf("%s", bodyBytes)
}
scanner := bufio.NewScanner(resp.Body)
buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize)
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line)
}
retryNeeded := false
// keep track of the last token generated, this is used to abort if the model starts looping
var lastToken string
var tokenRepeat int
var c completion
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
}
for scanner.Scan() {
select {
case <-ctx.Done():
// This handles the request cancellation
return ctx.Err()
switch {
case strings.TrimSpace(c.Content) == lastToken:
tokenRepeat++
default:
line := scanner.Bytes()
if len(line) == 0 {
continue
}
// try again on slot unavailable
if bytes.Contains(line, []byte("slot unavailable")) {
retryNeeded = true
break
}
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
return fmt.Errorf("error parsing llm response stream: %s", line)
}
var c completion
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
}
switch {
case strings.TrimSpace(c.Content) == lastToken:
tokenRepeat++
default:
lastToken = strings.TrimSpace(c.Content)
tokenRepeat = 0
}
// 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 {
slog.Debug("prediction aborted, token repeat limit reached")
return ctx.Err()
}
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
if c.Stop {
fn(CompletionResponse{
Done: true,
PromptEvalCount: c.Timings.PromptN,
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
})
return nil
}
lastToken = strings.TrimSpace(c.Content)
tokenRepeat = 0
}
}
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "unexpected EOF") {
s.Close()
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
// 30 picked as an arbitrary max token repeat limit, modify as needed
if tokenRepeat > 30 {
slog.Debug("prediction aborted, token repeat limit reached")
return ctx.Err()
}
return fmt.Errorf("error reading llm response: %v", err)
}
if !retryNeeded {
return nil // success
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
if c.Stop {
fn(CompletionResponse{
Done: true,
PromptEvalCount: c.Timings.PromptN,
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
})
return nil
}
}
}
// should never reach here ideally
return fmt.Errorf("max retries exceeded")
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "unexpected EOF") {
s.Close()
msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
}
return fmt.Errorf("error reading llm response: %v", err)
}
return nil
}
type EmbeddingRequest struct {
@ -702,13 +770,19 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
status, err := s.getServerStatusRetry(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status)
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
@ -754,13 +828,13 @@ type TokenizeResponse struct {
Tokens []int `json:"tokens"`
}
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady {
return nil, fmt.Errorf("unexpected server status: %d", status)
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: content})
@ -806,13 +880,13 @@ type DetokenizeResponse struct {
Content string `json:"content"`
}
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return "", err
} else if status != ServerStatusReady {
return "", fmt.Errorf("unexpected server status: %d", status)
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@ -850,15 +924,25 @@ func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
return decoded.Content, nil
}
func (s *LlamaServer) Close() error {
func (s *llmServer) Close() error {
if s.cmd != nil {
slog.Debug("stopping llama server")
return s.cmd.Process.Kill()
if err := s.cmd.Process.Kill(); err != nil {
return err
}
_ = s.cmd.Wait()
slog.Debug("llama server stopped")
}
return nil
}
func (s *llmServer) EstimatedVRAM() uint64 {
return s.estimatedVRAM
}
func parseDurationMs(ms float64) time.Duration {
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {

View file

@ -19,7 +19,7 @@ export default function () {
const [step, setStep] = useState<Step>(Step.WELCOME)
const [commandCopied, setCommandCopied] = useState<boolean>(false)
const command = 'ollama run llama2'
const command = 'ollama run llama3'
return (
<div className='drag'>

Some files were not shown because too many files have changed in this diff Show more