diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index ffb2cf9d..4adab4f8 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -103,6 +103,7 @@ jobs: path: | llm/build/**/bin/* llm/build/**/*.a + dist/windows-amd64/** # ROCm generation step generate-windows-rocm: @@ -173,7 +174,9 @@ jobs: - uses: actions/upload-artifact@v4 with: name: generate-windows-rocm - path: llm/build/**/bin/* + path: | + llm/build/**/bin/* + dist/windows-amd64/** - uses: actions/upload-artifact@v4 with: name: windows-rocm-deps @@ -253,7 +256,9 @@ jobs: - uses: actions/upload-artifact@v4 with: name: generate-windows-cuda - path: llm/build/**/bin/* + path: | + llm/build/**/bin/* + dist/windows-amd64/** - uses: actions/upload-artifact@v4 with: name: windows-cuda-deps @@ -306,23 +311,18 @@ jobs: - uses: actions/download-artifact@v4 with: name: generate-windows-cpu - path: llm/build - uses: actions/download-artifact@v4 with: name: generate-windows-cuda - path: llm/build - uses: actions/download-artifact@v4 with: name: windows-cuda-deps - path: dist/deps - uses: actions/download-artifact@v4 with: name: windows-rocm-deps - path: dist/deps - uses: actions/download-artifact@v4 with: name: generate-windows-rocm - path: llm/build - run: dir llm/build - run: | $gopath=(get-command go).source | split-path -parent @@ -331,13 +331,13 @@ jobs: $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" $env:PATH="$gopath;$env:PATH" $env:OLLAMA_SKIP_GENERATE="1" - $env:NVIDIA_DIR=$(resolve-path ".\dist\deps") - $env:HIP_PATH=$(resolve-path ".\dist\deps") & .\scripts\build_windows.ps1 - uses: actions/upload-artifact@v4 with: name: dist-windows - path: dist/*.exe + path: | + dist/OllamaSetup.exe + dist/ollama-windows-*.zip # Linux x86 assets built using the container based build build-linux-amd64: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 26d754a9..9a2544b8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,5 +1,15 @@ name: test +concurrency: + # For PRs, later CI runs preempt previous ones. e.g. a force push on a PR + # cancels running CI jobs and starts all new ones. + # + # For non-PR pushes, concurrency.group needs to be unique for every distinct + # CI run we want to have happen. Use run_id, which in practice means all + # non-PR CI runs will be allowed to run without preempting each other. 
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.run_id }} + cancel-in-progress: true + on: pull_request: paths: @@ -21,7 +31,9 @@ jobs: - id: changes run: | changed() { - git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \ + git diff-tree -r --no-commit-id --name-only \ + $(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \ + ${{ github.event.pull_request.head.sha }} \ | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))" } @@ -103,7 +115,9 @@ jobs: - uses: actions/upload-artifact@v4 with: name: cuda-${{ matrix.cuda-version }}-libraries - path: llm/build/**/bin/* + path: | + llm/build/**/bin/* + dist/windows-amd64/** generate-rocm: needs: [changes] if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }} @@ -134,7 +148,9 @@ jobs: - uses: actions/upload-artifact@v4 with: name: rocm-${{ matrix.rocm-version }}-libraries - path: llm/build/**/bin/* + path: | + llm/build/**/bin/* + dist/windows-amd64/** # ROCm generation step generate-windows-rocm: @@ -253,14 +269,9 @@ jobs: mkdir -p llm/build/darwin/$ARCH/stub/bin touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server if: ${{ startsWith(matrix.os, 'macos-') }} - - run: | - mkdir -p llm/build/windows/$ARCH/stub/bin - touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server - if: ${{ startsWith(matrix.os, 'windows-') }} - shell: bash - uses: golangci/golangci-lint-action@v4 with: - args: --timeout 8m0s + args: --timeout 8m0s -v test: strategy: matrix: @@ -284,7 +295,6 @@ jobs: with: go-version-file: go.mod cache: true - - run: go get - run: | case ${{ matrix.arch }} in amd64) echo ARCH=x86_64 ;; @@ -299,10 +309,6 @@ jobs: mkdir -p llm/build/darwin/$ARCH/stub/bin touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server if: ${{ startsWith(matrix.os, 'macos-') }} - - run: | - mkdir -p llm/build/windows/$ARCH/stub/bin - touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server - if: ${{ startsWith(matrix.os, 'windows-') }} shell: bash - run: go generate ./...
- run: go build diff --git a/.gitignore b/.gitignore index e0362a19..0d826ab6 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ ggml-metal.metal .idea test_data *.crt -llm/build \ No newline at end of file +llm/build +__debug_bin* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 0698f749..72edef2a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,7 +18,7 @@ ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH COPY --from=llm-code / /go/src/github.com/ollama/ollama/ WORKDIR /go/src/github.com/ollama/ollama/llm/generate ARG CGO_CFLAGS -RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh +RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64 ARG CMAKE_VERSION @@ -28,7 +28,7 @@ ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH COPY --from=llm-code / /go/src/github.com/ollama/ollama/ WORKDIR /go/src/github.com/ollama/ollama/llm/generate ARG CGO_CFLAGS -RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh +RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64 ARG CMAKE_VERSION @@ -40,7 +40,7 @@ COPY --from=llm-code / /go/src/github.com/ollama/ollama/ WORKDIR /go/src/github.com/ollama/ollama/llm/generate ARG CGO_CFLAGS ARG AMDGPU_TARGETS -RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh +RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh RUN mkdir /tmp/scratch && \ for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \ cp ${dep} /tmp/scratch/ || exit 1 ; \ @@ -64,11 +64,11 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64 RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64 -RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh +RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64 -RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh +RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64 -RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh +RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64 ARG CMAKE_VERSION @@ -84,7 +84,7 @@ WORKDIR /go/src/github.com/ollama/ollama/llm/generate FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64 RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64 -RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh +RUN OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh # Intermediate stage used for ./scripts/build_linux.sh diff --git a/README.md b/README.md index 7eece163..4f980375 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
-<img src="…" alt="ollama"> +<img src="…" alt="ollama">
# Ollama @@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla ## Quickstart -To run and chat with [Llama 2](https://ollama.com/library/llama2): +To run and chat with [Llama 3](https://ollama.com/library/llama3): ``` -ollama run llama2 +ollama run llama3 ``` ## Model library @@ -49,17 +49,14 @@ Here are some example models that can be downloaded: | Model | Parameters | Size | Download | | ------------------ | ---------- | ----- | ------------------------------ | -| Llama 2 | 7B | 3.8GB | `ollama run llama2` | +| Llama 3 | 8B | 4.7GB | `ollama run llama3` | +| Llama 3 | 70B | 40GB | `ollama run llama3:70b` | +| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` | | Mistral | 7B | 4.1GB | `ollama run mistral` | -| Dolphin Phi | 2.7B | 1.6GB | `ollama run dolphin-phi` | -| Phi-2 | 2.7B | 1.7GB | `ollama run phi` | | Neural Chat | 7B | 4.1GB | `ollama run neural-chat` | | Starling | 7B | 4.1GB | `ollama run starling-lm` | | Code Llama | 7B | 3.8GB | `ollama run codellama` | | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | -| Llama 2 13B | 13B | 7.3GB | `ollama run llama2:13b` | -| Llama 2 70B | 70B | 39GB | `ollama run llama2:70b` | -| Orca Mini | 3B | 1.9GB | `ollama run orca-mini` | | LLaVA | 7B | 4.5GB | `ollama run llava` | | Gemma | 2B | 1.4GB | `ollama run gemma:2b` | | Gemma | 7B | 4.8GB | `ollama run gemma:7b` | @@ -97,16 +94,16 @@ See the [guide](docs/import.md) on importing models for more information. ### Customize a prompt -Models from the Ollama library can be customized with a prompt. For example, to customize the `llama2` model: +Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model: ``` -ollama pull llama2 +ollama pull llama3 ``` Create a `Modelfile`: ``` -FROM llama2 +FROM llama3 # set the temperature to 1 [higher is more creative, lower is more coherent] PARAMETER temperature 1 @@ -141,7 +138,7 @@ ollama create mymodel -f ./Modelfile ### Pull a model ``` -ollama pull llama2 +ollama pull llama3 ``` > This command can also be used to update a local model. Only the diff will be pulled. @@ -149,13 +146,13 @@ ollama pull llama2 ### Remove a model ``` -ollama rm llama2 +ollama rm llama3 ``` ### Copy a model ``` -ollama cp llama2 my-llama2 +ollama cp llama3 my-model ``` ### Multiline input @@ -176,10 +173,10 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol The image features a yellow smiley face, which is likely the central focus of the picture. ``` -### Pass in prompt as arguments +### Pass the prompt as an argument ``` -$ ollama run llama2 "Summarize this file: $(cat README.md)" +$ ollama run llama3 "Summarize this file: $(cat README.md)" Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications. ``` @@ -226,7 +223,7 @@ Next, start the server: Finally, in a separate shell, run a model: ``` -./ollama run llama2 +./ollama run llama3 ``` ## REST API @@ -237,7 +234,7 @@ Ollama has a REST API for running and managing models. ``` curl http://localhost:11434/api/generate -d '{ - "model": "llama2", + "model": "llama3", "prompt":"Why is the sky blue?" 
}' ``` @@ -246,7 +243,7 @@ curl http://localhost:11434/api/generate -d '{ ``` curl http://localhost:11434/api/chat -d '{ - "model": "mistral", + "model": "llama3", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] @@ -259,16 +256,18 @@ See the [API documentation](./docs/api.md) for all endpoints. ### Web & Desktop +- [Open WebUI](https://github.com/open-webui/open-webui) +- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted) +- [Hollama](https://github.com/fmaclen/hollama) - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui) - [LibreChat](https://github.com/danny-avila/LibreChat) - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) -- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted) - [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [Saddle](https://github.com/jikkuatwork/saddle) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) +- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui) - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file) - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui) -- [Open WebUI](https://github.com/open-webui/open-webui) - [Ollamac](https://github.com/kevinhermawan/Ollamac) - [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md) - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core) @@ -286,13 +285,20 @@ See the [API documentation](./docs/api.md) for all endpoints. - [OllamaGUI](https://github.com/enoch1118/ollamaGUI) - [OpenAOE](https://github.com/InternLM/OpenAOE) - [Odin Runes](https://github.com/leonid20000/OdinRunes) -- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x) +- [LLM-X](https://github.com/mrdjohnson/llm-x) (Progressive Web App) - [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm) - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat) - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats) -- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama) -- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) -- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow) +- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository) +- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases) +- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG) +- [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding) +- [StreamDeploy](https://github.com/StreamDeploy-DevRel/streamdeploy-llm-app-scaffold) (LLM Application Scaffold) +- [chat](https://github.com/swuecho/chat) (chat web app for teams) +- [Lobe Chat](https://github.com/lobehub/lobe-chat) with [Integrating Doc](https://lobehub.com/docs/self-hosting/examples/ollama) +- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG) +- [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation) +- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other 
compatible API back-ends) ### Terminal @@ -308,11 +314,13 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Oatmeal](https://github.com/dustinblackman/oatmeal) - [cmdh](https://github.com/pgibler/cmdh) - [ooo](https://github.com/npahlfer/ooo) +- [shell-pilot](https://github.com/reid41/shell-pilot) - [tenere](https://github.com/pythops/tenere) - [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/). - [typechat-cli](https://github.com/anaisbetts/typechat-cli) - [ShellOracle](https://github.com/djcopley/ShellOracle) - [tlm](https://github.com/yusufcanb/tlm) +- [podman-ollama](https://github.com/ericcurtin/podman-ollama) ### Database @@ -344,9 +352,11 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md) - [Elixir LangChain](https://github.com/brainlid/langchain) - [Ollama for R - rollama](https://github.com/JBGruber/rollama) +- [Ollama for R - ollama-r](https://github.com/hauselin/ollama-r) - [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex) - [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama) - [Testcontainers](https://testcontainers.com/modules/ollama/) +- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama) ### Mobile @@ -366,17 +376,20 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram) - [Hass Ollama Conversation](https://github.com/ej52/hass-ollama-conversation) - [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama) -- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama) - [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot) - [Cliobot](https://github.com/herval/cliobot) (Telegram bot with Ollama support) - [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot) - [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt) - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama) +- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama) +- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot) - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama) - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace) - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension) - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend) - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support) +- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation) ### Supported backends -- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. +- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. + diff --git a/api/client.go b/api/client.go index a1ebdcd4..5b1fc796 100644 --- a/api/client.go +++ b/api/client.go @@ -1,9 +1,16 @@ // Package api implements the client-side API for code wishing to interact // with the ollama service. 
The methods of the [Client] type correspond to -// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md -// +// the ollama REST API as described in [the API documentation]. // The ollama command-line client itself uses this package to interact with // the backend service. +// +// # Examples +// +// Several examples of using this package are available [in the GitHub +// repository]. +// +// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md +// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples package api import ( @@ -18,6 +25,7 @@ import ( "net/url" "os" "runtime" + "strconv" "strings" "github.com/ollama/ollama/format" @@ -57,12 +65,36 @@ func checkError(resp *http.Response, body []byte) error { // If the variable is not specified, a default ollama host and port will be // used. func ClientFromEnvironment() (*Client, error) { + ollamaHost, err := GetOllamaHost() + if err != nil { + return nil, err + } + + return &Client{ + base: &url.URL{ + Scheme: ollamaHost.Scheme, + Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port), + }, + http: http.DefaultClient, + }, nil +} + +type OllamaHost struct { + Scheme string + Host string + Port string +} + +func GetOllamaHost() (OllamaHost, error) { defaultPort := "11434" - scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") + hostVar := os.Getenv("OLLAMA_HOST") + hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) + + scheme, hostport, ok := strings.Cut(hostVar, "://") switch { case !ok: - scheme, hostport = "http", os.Getenv("OLLAMA_HOST") + scheme, hostport = "http", hostVar case scheme == "http": defaultPort = "80" case scheme == "https": @@ -82,15 +114,24 @@ func ClientFromEnvironment() (*Client, error) { } } - return &Client{ - base: &url.URL{ - Scheme: scheme, - Host: net.JoinHostPort(host, port), - }, - http: http.DefaultClient, + if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { + return OllamaHost{}, ErrInvalidHostPort + } + + return OllamaHost{ + Scheme: scheme, + Host: host, + Port: port, }, nil } +func NewClient(base *url.URL, http *http.Client) *Client { + return &Client{ + base: base, + http: http, + } +} + func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { var reqBody io.Reader var data []byte @@ -265,8 +306,14 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc }) } +// PushProgressFunc is a function that [Client.Push] invokes when progress is +// made. +// It's similar to other progress function types like [PullProgressFunc]. type PushProgressFunc func(ProgressResponse) error +// Push uploads a model to the model library; requires registering for ollama.ai +// and adding a public key first. fn is called each time progress is made on +// the request and can be used to display a progress bar, etc. func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error { return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error { var resp ProgressResponse @@ -278,8 +325,15 @@ func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc }) } +// CreateProgressFunc is a function that [Client.Create] invokes when progress +// is made. +// It's similar to other progress function types like [PullProgressFunc]. type CreateProgressFunc func(ProgressResponse) error +// Create creates a model from a [Modelfile]. 
fn is a progress function that +// behaves similarly to other methods (see [Client.Pull]). +// +// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { var resp ProgressResponse @@ -291,6 +345,7 @@ func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgre }) } +// List lists models that are available locally. func (c *Client) List(ctx context.Context) (*ListResponse, error) { var lr ListResponse if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil { @@ -299,6 +354,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) { return &lr, nil } +// Copy copies a model, creating a model with another name from an existing +// model. func (c *Client) Copy(ctx context.Context, req *CopyRequest) error { if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil { return err @@ -306,6 +363,7 @@ func (c *Client) Copy(ctx context.Context, req *CopyRequest) error { return nil } +// Delete deletes a model and its data. func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error { if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil { return err @@ -313,6 +371,7 @@ func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error { return nil } +// Show obtains model information, including details, modelfile, license, etc. func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) { var resp ShowResponse if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil { @@ -321,12 +380,16 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err return &resp, nil } +// Heartbeat checks if the server has started and is responsive; if yes, it +// returns nil, otherwise an error. func (c *Client) Heartbeat(ctx context.Context) error { if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil { return err } return nil } + +// Embeddings generates embeddings from a model. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { var resp EmbeddingResponse if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil { @@ -335,10 +398,13 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd return &resp, nil } +// CreateBlob creates a blob from a file on the server. digest is the +// expected SHA256 digest of the file, and r represents the file. func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error { return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil) } +// Version returns the Ollama server version as a string.
func (c *Client) Version(ctx context.Context) (string, error) { var version struct { Version string `json:"version"` } diff --git a/api/client_test.go b/api/client_test.go index 0eafedca..b2c51d00 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,12 @@ package api -import "testing" +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) func TestClientFromEnvironment(t *testing.T) { type testCase struct { @@ -40,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) { } }) } + + hostTestCases := map[string]*testCase{ + "empty": {value: "", expect: "127.0.0.1:11434"}, + "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, + "only port": {value: ":1234", expect: ":1234"}, + "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, + "hostname": {value: "example.com", expect: "example.com:11434"}, + "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, + "zero port": {value: ":0", expect: ":0"}, + "too large port": {value: ":66000", err: ErrInvalidHostPort}, + "too small port": {value: ":-1", err: ErrInvalidHostPort}, + "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, + "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, + "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, + "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, + "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, + "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, + "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, + "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, + } + + for k, v := range hostTestCases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", v.value) + + oh, err := GetOllamaHost() + if err != v.err { + t.Fatalf("expected %s, got %s", v.err, err) + } + + if err == nil { + host := net.JoinHostPort(oh.Host, oh.Port) + assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) + } + }) + } } diff --git a/api/types.go b/api/types.go index 2e762b39..5d0212e5 100644 --- a/api/types.go +++ b/api/types.go @@ -2,6 +2,7 @@ package api import ( "encoding/json" + "errors" "fmt" "math" "os" @@ -11,6 +12,7 @@ import ( "time" ) +// StatusError is an error with an HTTP status code. type StatusError struct { StatusCode int Status string @@ -31,6 +33,7 @@ func (e StatusError) Error() string { } } +// ImageData represents the raw binary data of an image file. type ImageData []byte // GenerateRequest describes a request sent by [Client.Generate]. While you @@ -76,22 +79,39 @@ type GenerateRequest struct { Options map[string]interface{} `json:"options"` } +// ChatRequest describes a request sent by [Client.Chat]. type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Stream *bool `json:"stream,omitempty"` - Format string `json:"format"` + // Model is the model name, as in [GenerateRequest]. + Model string `json:"model"` + + // Messages is the messages of the chat; it can be used to keep a chat memory. + Messages []Message `json:"messages"` + + // Stream enables streaming of the returned response; true by default. + Stream *bool `json:"stream,omitempty"` + + // Format is the format to return the response in (e.g. "json"). + Format string `json:"format"` + + // KeepAlive controls how long the model will stay loaded into memory + // following the request. + KeepAlive *Duration `json:"keep_alive,omitempty"` + // Options lists model-specific options.
Options map[string]interface{} `json:"options"` } +// Message is a single message in a chat sequence. The message contains the +// role ("system", "user", or "assistant"), the content, and an optional list +// of images. type Message struct { - Role string `json:"role"` // one of ["system", "user", "assistant"] + Role string `json:"role"` Content string `json:"content"` Images []ImageData `json:"images,omitempty"` } +// ChatResponse is the response returned by [Client.Chat]. Its fields are +// similar to [GenerateResponse]. type ChatResponse struct { Model string `json:"model"` CreatedAt time.Time `json:"created_at"` @@ -111,7 +131,8 @@ type Metrics struct { EvalDuration time.Duration `json:"eval_duration,omitempty"` } -// Options specified in GenerateRequest, if you add a new option here add it to the API docs also +// Options specified in [GenerateRequest]. If you add a new option here, add +// it to the API docs also. type Options struct { Runner @@ -157,18 +178,28 @@ type Runner struct { RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"` } +// EmbeddingRequest is the request passed to [Client.Embeddings]. type EmbeddingRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` + // Model is the model name. + Model string `json:"model"` + + // Prompt is the textual prompt to embed. + Prompt string `json:"prompt"` + + // KeepAlive controls how long the model will stay loaded in memory following + // this request. KeepAlive *Duration `json:"keep_alive,omitempty"` + // Options lists model-specific options. Options map[string]interface{} `json:"options"` } +// EmbeddingResponse is the response from [Client.Embeddings]. type EmbeddingResponse struct { Embedding []float64 `json:"embedding"` } +// CreateRequest is the request passed to [Client.Create]. type CreateRequest struct { Model string `json:"model"` Path string `json:"path"` @@ -180,6 +211,7 @@ type CreateRequest struct { Name string `json:"name"` } +// DeleteRequest is the request passed to [Client.Delete]. type DeleteRequest struct { Model string `json:"model"` @@ -187,6 +219,7 @@ type DeleteRequest struct { Name string `json:"name"` } +// ShowRequest is the request passed to [Client.Show]. type ShowRequest struct { Model string `json:"model"` System string `json:"system"` @@ -198,6 +231,7 @@ type ShowRequest struct { Name string `json:"name"` } +// ShowResponse is the response returned from [Client.Show]. type ShowResponse struct { License string `json:"license,omitempty"` Modelfile string `json:"modelfile,omitempty"` @@ -208,11 +242,13 @@ type ShowResponse struct { Messages []Message `json:"messages,omitempty"` } +// CopyRequest is the request passed to [Client.Copy]. type CopyRequest struct { Source string `json:"source"` Destination string `json:"destination"` } +// PullRequest is the request passed to [Client.Pull]. type PullRequest struct { Model string `json:"model"` Insecure bool `json:"insecure,omitempty"` @@ -224,6 +260,8 @@ type PullRequest struct { Name string `json:"name"` } +// ProgressResponse is the response passed to progress functions like +// [PullProgressFunc] and [PushProgressFunc]. type ProgressResponse struct { Status string `json:"status"` Digest string `json:"digest,omitempty"` @@ -231,6 +269,7 @@ type ProgressResponse struct { Completed int64 `json:"completed,omitempty"` } +// PushRequest is the request passed to [Client.Push].
type PushRequest struct { Model string `json:"model"` Insecure bool `json:"insecure,omitempty"` @@ -242,10 +281,12 @@ type PushRequest struct { Name string `json:"name"` } +// ListResponse is the response from [Client.List]. type ListResponse struct { Models []ModelResponse `json:"models"` } +// ModelResponse is a single model description in [ListResponse]. type ModelResponse struct { Name string `json:"name"` Model string `json:"model"` @@ -259,17 +300,28 @@ type TokenResponse struct { Token string `json:"token"` } +// GenerateResponse is the response passed into [GenerateResponseFunc]. type GenerateResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Response string `json:"response"` + // Model is the model name that generated the response. + Model string `json:"model"` - Done bool `json:"done"` + // CreatedAt is the timestamp of the response. + CreatedAt time.Time `json:"created_at"` + + // Response is the textual response itself. + Response string `json:"response"` + + // Done specifies if the response is complete. + Done bool `json:"done"` + + // Context is an encoding of the conversation used in this response; this + // can be sent in the next request to keep a conversational memory. Context []int `json:"context,omitempty"` Metrics } +// ModelDetails provides details about a model. type ModelDetails struct { ParentModel string `json:"parent_model"` Format string `json:"format"` @@ -307,7 +359,9 @@ func (m *Metrics) Summary() { } } -var ErrInvalidOpts = fmt.Errorf("invalid options") +// ErrInvalidOpts is returned when invalid options are passed to the client. +var ErrInvalidOpts = errors.New("invalid options") +var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") func (opts *Options) FromMap(m map[string]interface{}) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct @@ -392,11 +446,15 @@ func (opts *Options) FromMap(m map[string]interface{}) error { return nil } +// DefaultOptions is the default set of options for [GenerateRequest]; these +// values are used unless the user specifies other values explicitly.
func DefaultOptions() Options { return Options{ // options set on request to runner - NumPredict: -1, - NumKeep: 0, + NumPredict: -1, + + // set a minimal num_keep to avoid issues on context shifts + NumKeep: 4, Temperature: 0.8, TopK: 40, TopP: 0.9, @@ -432,6 +490,13 @@ type Duration struct { time.Duration } +func (d Duration) MarshalJSON() ([]byte, error) { + if d.Duration < 0 { + return []byte("-1"), nil + } + return []byte("\"" + d.Duration.String() + "\""), nil +} + func (d *Duration) UnmarshalJSON(b []byte) (err error) { var v any if err := json.Unmarshal(b, &v); err != nil { @@ -445,7 +510,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { if t < 0 { d.Duration = time.Duration(math.MaxInt64) } else { - d.Duration = time.Duration(t * float64(time.Second)) + d.Duration = time.Duration(int(t) * int(time.Second)) } case string: d.Duration, err = time.ParseDuration(t) @@ -455,6 +520,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { if d.Duration < 0 { d.Duration = time.Duration(math.MaxInt64) } + default: + return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v)) } return nil diff --git a/api/types_test.go b/api/types_test.go index 5a093be2..cfe1331f 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { req: `{ "keep_alive": 42 }`, exp: &Duration{42 * time.Second}, }, + { + name: "Positive Float", + req: `{ "keep_alive": 42.5 }`, + exp: &Duration{42 * time.Second}, + }, { name: "Positive Integer String", req: `{ "keep_alive": "42m" }`, @@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { req: `{ "keep_alive": -1 }`, exp: &Duration{math.MaxInt64}, }, + { + name: "Negative Float", + req: `{ "keep_alive": -3.14 }`, + exp: &Duration{math.MaxInt64}, + }, { name: "Negative Integer String", req: `{ "keep_alive": "-1m" }`, @@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { }) } } + +func TestDurationMarshalUnmarshal(t *testing.T) { + tests := []struct { + name string + input time.Duration + expected time.Duration + }{ + { + "negative duration", + time.Duration(-1), + time.Duration(math.MaxInt64), + }, + { + "positive duration", + time.Duration(42 * time.Second), + time.Duration(42 * time.Second), + }, + { + "another positive duration", + time.Duration(42 * time.Minute), + time.Duration(42 * time.Minute), + }, + { + "zero duration", + time.Duration(0), + time.Duration(0), + }, + { + "max duration", + time.Duration(math.MaxInt64), + time.Duration(math.MaxInt64), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + b, err := json.Marshal(Duration{test.input}) + require.NoError(t, err) + + var d Duration + err = json.Unmarshal(b, &d) + require.NoError(t, err) + + assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration) + }) + } +} diff --git a/app/lifecycle/logging.go b/app/lifecycle/logging.go index 98df9b41..4be90648 100644 --- a/app/lifecycle/logging.go +++ b/app/lifecycle/logging.go @@ -5,12 +5,14 @@ import ( "log/slog" "os" "path/filepath" + + "github.com/ollama/ollama/server/envconfig" ) func InitLogging() { level := slog.LevelInfo - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { level = slog.LevelDebug } diff --git a/app/lifecycle/server.go b/app/lifecycle/server.go index 8680e7bc..3c11edb8 100644 --- a/app/lifecycle/server.go +++ b/app/lifecycle/server.go @@ -43,37 +43,36 @@ func getCLIFullPath(command string) string { return command } -func 
SpawnServer(ctx context.Context, command string) (chan int, error) { - done := make(chan int) - - logDir := filepath.Dir(ServerLogFile) - _, err := os.Stat(logDir) - if errors.Is(err, os.ErrNotExist) { - if err := os.MkdirAll(logDir, 0o755); err != nil { - return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err) - } - } - +func start(ctx context.Context, command string) (*exec.Cmd, error) { cmd := getCmd(ctx, getCLIFullPath(command)) - // send stdout and stderr to a file stdout, err := cmd.StdoutPipe() if err != nil { - return done, fmt.Errorf("failed to spawn server stdout pipe %s", err) + return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err) } stderr, err := cmd.StderrPipe() if err != nil { - return done, fmt.Errorf("failed to spawn server stderr pipe %s", err) - } - stdin, err := cmd.StdinPipe() - if err != nil { - return done, fmt.Errorf("failed to spawn server stdin pipe %s", err) + return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err) } // TODO - rotation logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755) if err != nil { - return done, fmt.Errorf("failed to create server log %w", err) + return nil, fmt.Errorf("failed to create server log: %w", err) } + + logDir := filepath.Dir(ServerLogFile) + _, err = os.Stat(logDir) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err) + + } + + if err := os.MkdirAll(logDir, 0o755); err != nil { + return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err) + } + } + go func() { defer logFile.Close() io.Copy(logFile, stdout) //nolint:errcheck @@ -117,19 +116,33 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) { // run the command and wait for it to finish if err := cmd.Start(); err != nil { - return done, fmt.Errorf("failed to start server %w", err) + return nil, fmt.Errorf("failed to start server: %w", err) } if cmd.Process != nil { slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid)) } slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile)) + return cmd, nil +} + +func SpawnServer(ctx context.Context, command string) (chan int, error) { + done := make(chan int) + go func() { // Keep the server running unless we're shutting down the app crashCount := 0 for { + slog.Info("starting server...") + cmd, err := start(ctx, command) + if err != nil { + crashCount++ + slog.Error(fmt.Sprintf("failed to start server %s", err)) + time.Sleep(500 * time.Millisecond * time.Duration(crashCount)) + continue + } + cmd.Wait() //nolint:errcheck - stdin.Close() var code int if cmd.ProcessState != nil { code = cmd.ProcessState.ExitCode() @@ -143,15 +156,12 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) { default: crashCount++ slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code)) - time.Sleep(500 * time.Millisecond) - if err := cmd.Start(); err != nil { - slog.Error(fmt.Sprintf("failed to restart server %s", err)) - // Keep trying, but back off if we keep failing - time.Sleep(time.Duration(crashCount) * time.Second) - } + time.Sleep(500 * time.Millisecond * time.Duration(crashCount)) + break } } }() + return done, nil } diff --git a/app/lifecycle/updater_windows.go b/app/lifecycle/updater_windows.go index f26c43c9..4053671a 100644 --- a/app/lifecycle/updater_windows.go +++ b/app/lifecycle/updater_windows.go @@ -31,16 +31,13 @@ func DoUpgrade(cancel context.CancelFunc, done chan
int) error { "/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd "/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed } - // When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts) - // TODO - temporarily disable since we're pinning in debug mode for the preview - // if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" { + // make the upgrade as quiet as possible (no GUI, no prompts) installArgs = append(installArgs, "/SP", // Skip the "This will install... Do you wish to continue" prompt "/SUPPRESSMSGBOXES", "/SILENT", "/VERYSILENT", ) - // } // Safeguard in case we have requests in flight that need to drain... slog.Info("Waiting for server to shutdown") diff --git a/app/ollama.iss b/app/ollama.iss index 8f46223b..9dc61abb 100644 --- a/app/ollama.iss +++ b/app/ollama.iss @@ -88,15 +88,12 @@ DialogFontSize=12 [Files] Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit -Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit +Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit +Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion -; Assumes v5.7, may need adjustments for v6 -#if GetEnv("HIP_PATH") != "" - Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion - Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion - ; amdhip64.dll dependency comes from the driver and must be installed already - Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion +#if DirExists("..\dist\windows-amd64\rocm") + Source: "..\dist\windows-amd64\rocm\*"; DestDir: "{app}\rocm\"; Flags: ignoreversion recursesubdirs #endif @@ -132,7 +129,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi ;FinishedHeadingLabel=Run your first model -;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama2 +;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3 ;ClickFinish=%n [Registry] diff --git a/app/tray/wintray/menus.go b/app/tray/wintray/menus.go index 74defa67..9cb3b893 100644 --- a/app/tray/wintray/menus.go +++ b/app/tray/wintray/menus.go @@ -1,71 +1,71 @@ -//go:build windows - -package wintray - -import ( - "fmt" - "log/slog" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - updatAvailableMenuID = 1 - updateMenuID = updatAvailableMenuID + 1 - separatorMenuID = updateMenuID + 1 - diagLogsMenuID = separatorMenuID + 1 - diagSeparatorMenuID = diagLogsMenuID + 1 - quitMenuID = diagSeparatorMenuID + 1 -) - -func (t *winTray) initMenus() error { - if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil { - return fmt.Errorf("unable to create menu entries %w\n", err) - } - if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil { - return fmt.Errorf("unable to create menu entries %w", err) - } - if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil { - return fmt.Errorf("unable to create menu entries %w\n", err) - } - return nil -} - -func (t 
*winTray) UpdateAvailable(ver string) error { - if !t.updateNotified { - slog.Debug("updating menu and sending notification for new update") - if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil { - return fmt.Errorf("unable to create menu entries %w", err) - } - if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil { - return fmt.Errorf("unable to create menu entries %w", err) - } - if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil { - return fmt.Errorf("unable to create menu entries %w", err) - } - iconFilePath, err := iconBytesToFilePath(wt.updateIcon) - if err != nil { - return fmt.Errorf("unable to write icon data to temp file: %w", err) - } - if err := wt.setIcon(iconFilePath); err != nil { - return fmt.Errorf("unable to set icon: %w", err) - } - t.updateNotified = true - - t.pendingUpdate = true - // Now pop up the notification - t.muNID.Lock() - defer t.muNID.Unlock() - copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle)) - copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver))) - t.nid.Flags |= NIF_INFO - t.nid.Timeout = 10 - t.nid.Size = uint32(unsafe.Sizeof(*wt.nid)) - err = t.nid.modify() - if err != nil { - return err - } - } - return nil -} +//go:build windows + +package wintray + +import ( + "fmt" + "log/slog" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + updatAvailableMenuID = 1 + updateMenuID = updatAvailableMenuID + 1 + separatorMenuID = updateMenuID + 1 + diagLogsMenuID = separatorMenuID + 1 + diagSeparatorMenuID = diagLogsMenuID + 1 + quitMenuID = diagSeparatorMenuID + 1 +) + +func (t *winTray) initMenus() error { + if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil { + return fmt.Errorf("unable to create menu entries %w\n", err) + } + if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil { + return fmt.Errorf("unable to create menu entries %w\n", err) + } + return nil +} + +func (t *winTray) UpdateAvailable(ver string) error { + if !t.updateNotified { + slog.Debug("updating menu and sending notification for new update") + if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + iconFilePath, err := iconBytesToFilePath(wt.updateIcon) + if err != nil { + return fmt.Errorf("unable to write icon data to temp file: %w", err) + } + if err := wt.setIcon(iconFilePath); err != nil { + return fmt.Errorf("unable to set icon: %w", err) + } + t.updateNotified = true + + t.pendingUpdate = true + // Now pop up the notification + t.muNID.Lock() + defer t.muNID.Unlock() + copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle)) + copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver))) + t.nid.Flags |= NIF_INFO + t.nid.Timeout = 10 + t.nid.Size = uint32(unsafe.Sizeof(*wt.nid)) + err = t.nid.modify() + if err != nil { + return err + } + } + return nil +} diff --git a/auth/auth.go b/auth/auth.go index ca64670d..026b2a2c 100644 --- 
a/auth/auth.go +++ b/auth/auth.go @@ -10,12 +10,44 @@ import ( "log/slog" "os" "path/filepath" + "strings" "golang.org/x/crypto/ssh" ) const defaultPrivateKey = "id_ed25519" +func keyPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + return filepath.Join(home, ".ollama", defaultPrivateKey), nil +} + +func GetPublicKey() (string, error) { + keyPath, err := keyPath() + if err != nil { + return "", err + } + + privateKeyFile, err := os.ReadFile(keyPath) + if err != nil { + slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) + return "", err + } + + privateKey, err := ssh.ParsePrivateKey(privateKeyFile) + if err != nil { + return "", err + } + + publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) + + return strings.TrimSpace(string(publicKey)), nil +} + func NewNonce(r io.Reader, length int) (string, error) { nonce := make([]byte, length) if _, err := io.ReadFull(r, nonce); err != nil { @@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) { } func Sign(ctx context.Context, bts []byte) (string, error) { - home, err := os.UserHomeDir() + keyPath, err := keyPath() if err != nil { return "", err } - keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) - privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) diff --git a/cmd/cmd.go b/cmd/cmd.go index f77c08b8..bf305d81 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -17,6 +17,7 @@ import ( "os" "os/signal" "path/filepath" + "regexp" "runtime" "strings" "syscall" @@ -31,10 +32,12 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/format" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" + "github.com/ollama/ollama/types/errtypes" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -53,14 +56,13 @@ func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() - bars := make(map[string]*progress.Bar) - - modelfile, err := os.ReadFile(filename) + f, err := os.Open(filename) if err != nil { return err } + defer f.Close() - commands, err := parser.Parse(bytes.NewReader(modelfile)) + modelfile, err := model.ParseFile(f) if err != nil { return err } @@ -74,10 +76,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error { spinner := progress.NewSpinner(status) p.Add(status, spinner) - for _, c := range commands { - switch c.Name { + for i := range modelfile.Commands { + switch modelfile.Commands[i].Name { case "model", "adapter": - path := c.Args + path := modelfile.Commands[i].Args if path == "~" { path = home } else if strings.HasPrefix(path, "~/") { @@ -89,101 +91,22 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } fi, err := os.Stat(path) - if errors.Is(err, os.ErrNotExist) && c.Name == "model" { + if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" { continue } else if err != nil { return err } - // TODO make this work w/ adapters if fi.IsDir() { - tf, err := os.CreateTemp("", "ollama-tf") + // this is likely a safetensors or pytorch directory + // TODO make this work w/ adapters + tempfile, err := tempZipFiles(path) if err != nil { return err } - defer os.RemoveAll(tf.Name()) + defer os.RemoveAll(tempfile) - zf := zip.NewWriter(tf) - - files := []string{} - - tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin")) - if 
err != nil { - return err - } else if len(tfiles) == 0 { - tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors")) - if err != nil { - return err - } - } - - files = append(files, tfiles...) - - if len(files) == 0 { - return fmt.Errorf("no models were found in '%s'", path) - } - - // add the safetensor/torch config file + tokenizer - files = append(files, filepath.Join(path, "config.json")) - files = append(files, filepath.Join(path, "params.json")) - files = append(files, filepath.Join(path, "added_tokens.json")) - files = append(files, filepath.Join(path, "tokenizer.model")) - - for _, fn := range files { - f, err := os.Open(fn) - - // just skip whatever files aren't there - if os.IsNotExist(err) { - if strings.HasSuffix(fn, "tokenizer.model") { - // try the parent dir before giving up - parentDir := filepath.Dir(path) - newFn := filepath.Join(parentDir, "tokenizer.model") - f, err = os.Open(newFn) - if os.IsNotExist(err) { - continue - } else if err != nil { - return err - } - } else { - continue - } - } else if err != nil { - return err - } - - fi, err := f.Stat() - if err != nil { - return err - } - - h, err := zip.FileInfoHeader(fi) - if err != nil { - return err - } - - h.Name = filepath.Base(fn) - h.Method = zip.Store - - w, err := zf.CreateHeader(h) - if err != nil { - return err - } - - _, err = io.Copy(w, f) - if err != nil { - return err - } - - } - - if err := zf.Close(); err != nil { - return err - } - - if err := tf.Close(); err != nil { - return err - } - path = tf.Name() + path = tempfile } digest, err := createBlob(cmd, client, path) @@ -191,10 +114,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - modelfile = bytes.ReplaceAll(modelfile, []byte(c.Args), []byte("@"+digest)) + modelfile.Commands[i].Args = "@" + digest } } + bars := make(map[string]*progress.Bar) fn := func(resp api.ProgressResponse) error { if resp.Digest != "" { spinner.Stop() @@ -220,7 +144,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { quantization, _ := cmd.Flags().GetString("quantization") - request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization} + request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantization: quantization} if err := client.Create(cmd.Context(), &request, fn); err != nil { return err } @@ -228,6 +152,114 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return nil } +func tempZipFiles(path string) (string, error) { + tempfile, err := os.CreateTemp("", "ollama-tf") + if err != nil { + return "", err + } + defer tempfile.Close() + + zipfile := zip.NewWriter(tempfile) + defer zipfile.Close() + + detectContentType := func(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + + var b bytes.Buffer + b.Grow(512) + + if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { + return "", err + } + + contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") + return contentType, nil + } + + glob := func(pattern, contentType string) ([]string, error) { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + + for _, safetensor := range matches { + if ct, err := detectContentType(safetensor); err != nil { + return nil, err + } else if ct != contentType { + return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor) + } + } + + return matches, nil + } + + var files []string + if st, _ := 
glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 { + // safetensors files might be unresolved git lfs references; skip if they are + // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors + files = append(files, st...) + } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { + // pytorch files might also be unresolved git lfs references; skip if they are + // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin + files = append(files, pt...) + } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/octet-stream"); len(pt) > 0 { + // pytorch files might also be unresolved git lfs references; skip if they are + // covers consolidated.x.pth, consolidated.pth + files = append(files, pt...) + } else { + return "", errors.New("no safetensors or torch files found") + } + + // add configuration files; json files are detected as text/plain + js, err := glob(filepath.Join(path, "*.json"), "text/plain") + if err != nil { + return "", err + } + files = append(files, js...) + + if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { + // add tokenizer.model if it exists; tokenizer.json is automatically picked up by the previous glob + // tokenizer.model might be an unresolved git lfs reference; error if it is + files = append(files, tks...) + } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { + // sometimes tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) + files = append(files, tks...) + } + + for _, file := range files { + f, err := os.Open(file) + if err != nil { + return "", err + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return "", err + } + + zfi, err := zip.FileInfoHeader(fi) + if err != nil { + return "", err + } + + zf, err := zipfile.CreateHeader(zfi) + if err != nil { + return "", err + } + + if _, err := io.Copy(zf, f); err != nil { + return "", err + } + } + + return tempfile.Name(), nil +} + func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { bin, err := os.Open(path) if err != nil { @@ -322,6 +354,47 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generateInteractive(cmd, opts) } +func errFromUnknownKey(unknownKeyErr error) error { + // find SSH public key in the error message + sshKeyPattern := `ssh-\w+ [^\s"]+` + re := regexp.MustCompile(sshKeyPattern) + matches := re.FindStringSubmatch(unknownKeyErr.Error()) + + if len(matches) > 0 { + serverPubKey := matches[0] + + localPubKey, err := auth.GetPublicKey() + if err != nil { + return unknownKeyErr + } + + if runtime.GOOS == "linux" && serverPubKey != localPubKey { + // try the ollama service public key + svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") + if err != nil { + return unknownKeyErr + } + localPubKey = strings.TrimSpace(string(svcPubKey)) + } + + // check if the returned public key matches the local public key; this prevents adding a remote key to the user's account + if serverPubKey != localPubKey { + return unknownKeyErr + } + + var msg strings.Builder + msg.WriteString(unknownKeyErr.Error()) + msg.WriteString("\n\nYour ollama key is:\n") + msg.WriteString(localPubKey) + msg.WriteString("\nAdd your key at:\n") + msg.WriteString("https://ollama.com/settings/keys") + + return errors.New(msg.String()) + } + + return unknownKeyErr +} + func
 func PushHandler(cmd *cobra.Command, args []string) error {
 client, err := api.ClientFromEnvironment()
 if err != nil {
@@ -369,6 +442,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {

 request := api.PushRequest{Name: args[0], Insecure: insecure}
 if err := client.Push(cmd.Context(), &request, fn); err != nil {
+ if spinner != nil {
+ spinner.Stop()
+ }
+ if strings.Contains(err.Error(), "access denied") {
+ return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
+ }
+ host := model.ParseName(args[0]).Host
+ isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
+ if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
+ // the user has not added their ollama key to ollama.com
+ // re-throw an error with a more user-friendly message
+ return errFromUnknownKey(err)
+ }
+
 return err
 }

@@ -796,24 +883,27 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 }

 func RunServer(cmd *cobra.Command, _ []string) error {
- host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
+ // retrieve the OLLAMA_HOST environment variable
+ ollamaHost, err := api.GetOllamaHost()
 if err != nil {
- host, port = "127.0.0.1", "11434"
- if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
- host = ip.String()
- }
+ return err
 }

 if err := initializeKeypair(); err != nil {
 return err
 }

- ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
+ ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
 if err != nil {
 return err
 }

- return server.Serve(ln)
+ err = server.Serve(ln)
+ if errors.Is(err, http.ErrServerClosed) {
+ return nil
+ }
+
+ return err
 }

 func initializeKeypair() error {
@@ -1034,7 +1124,7 @@ Environment Variables:
 RunE: ListHandler,
 }
 copyCmd := &cobra.Command{
- Use: "cp SOURCE TARGET",
+ Use: "cp SOURCE DESTINATION",
 Short: "Copy a model",
 Args: cobra.ExactArgs(2),
 PreRunE: checkServerHeartbeat,
diff --git a/cmd/interactive.go b/cmd/interactive.go
index 12c31052..c294b7b5 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -94,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 fmt.Fprintln(os.Stderr, " /show Show model information")
 fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
 fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
+ fmt.Fprintln(os.Stderr, " /clear Clear session context")
 fmt.Fprintln(os.Stderr, " /bye Exit")
 fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
 fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@@ -161,7 +162,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
 fmt.Fprintln(os.Stderr, " /set parameter repeat_last_n <int> Set how far back to look for repetitions")
 fmt.Fprintln(os.Stderr, " /set parameter num_gpu <int> The number of layers to send to the GPU")
- fmt.Fprintln(os.Stderr, " /set parameter stop \"<string>\", ... Set the stop parameters")
+ fmt.Fprintln(os.Stderr, " /set parameter stop <stop> <stop> ...
Set the stop parameters") fmt.Fprintln(os.Stderr, "") } @@ -280,6 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } fmt.Printf("Created new model '%s'\n", args[1]) continue + case strings.HasPrefix(line, "/clear"): + opts.Messages = []api.Message{} + fmt.Println("Cleared session context") + continue case strings.HasPrefix(line, "/set"): args := strings.Fields(line) if len(args) > 1 { diff --git a/convert/convert.go b/convert/convert.go index bf6f0bf5..f4210e50 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "io" "log/slog" "os" "path/filepath" @@ -18,19 +19,23 @@ import ( ) type Params struct { - Architectures []string `json:"architectures"` - VocabSize int `json:"vocab_size"` - HiddenSize int `json:"hidden_size"` // n_embd - HiddenLayers int `json:"num_hidden_layers"` // n_layer - ContextSize int `json:"max_position_embeddings"` - IntermediateSize int `json:"intermediate_size"` - AttentionHeads int `json:"num_attention_heads"` // n_head - KeyValHeads int `json:"num_key_value_heads"` - NormEPS float64 `json:"rms_norm_eps"` - BoSTokenID int `json:"bos_token_id"` - EoSTokenID int `json:"eos_token_id"` - HeadDimension int `json:"head_dim"` - PaddingTokenID int `json:"pad_token_id"` + Architectures []string `json:"architectures"` + VocabSize int `json:"vocab_size"` + HiddenSize int `json:"hidden_size"` // n_embd + HiddenLayers int `json:"num_hidden_layers"` // n_layer + ContextSize int `json:"max_position_embeddings"` + IntermediateSize int `json:"intermediate_size"` + AttentionHeads int `json:"num_attention_heads"` // n_head + KeyValHeads int `json:"num_key_value_heads"` + NormEPS float64 `json:"rms_norm_eps"` + BoSTokenID int `json:"bos_token_id"` + EoSTokenID int `json:"eos_token_id"` + HeadDimension int `json:"head_dim"` + PaddingTokenID int `json:"pad_token_id"` + RopeFrequencyBase float64 `json:"rope_theta"` + + Experts int `json:"num_local_experts"` + ExpertsUsed int `json:"num_experts_per_tok"` ByteOrder } @@ -43,7 +48,7 @@ type ByteOrder interface { type ModelArch interface { GetTensors() error LoadVocab() error - WriteGGUF() (string, error) + WriteGGUF(io.WriteSeeker) error } type ModelFormat interface { diff --git a/convert/gemma.go b/convert/gemma.go index 648a4ad9..88abe646 100644 --- a/convert/gemma.go +++ b/convert/gemma.go @@ -94,7 +94,7 @@ func (m *GemmaModel) LoadVocab() error { return nil } -func (m *GemmaModel) WriteGGUF() (string, error) { +func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error { kv := llm.KV{ "general.architecture": "gemma", "general.name": m.Name, @@ -122,16 +122,5 @@ func (m *GemmaModel) WriteGGUF() (string, error) { "tokenizer.ggml.add_eos_token": false, } - f, err := os.CreateTemp("", "ollama-gguf") - if err != nil { - return "", err - } - defer f.Close() - - mod := llm.NewGGUFV3(m.Params.ByteOrder) - if err := mod.Encode(f, kv, m.Tensors); err != nil { - return "", err - } - - return f.Name(), nil + return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) } diff --git a/convert/llama.go b/convert/llama.go index c7f7b290..fb576e2e 100644 --- a/convert/llama.go +++ b/convert/llama.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "log/slog" - "os" "regexp" "strings" @@ -132,7 +131,7 @@ func (m *LlamaModel) LoadVocab() error { return nil } -func (m *LlamaModel) WriteGGUF() (string, error) { +func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error { kv := llm.KV{ "general.architecture": "llama", "general.name": m.Name, @@ -159,18 +158,5 @@ func (m 
*LlamaModel) WriteGGUF() (string, error) {
 "tokenizer.ggml.add_eos_token": false,
 }

- f, err := os.CreateTemp("", "ollama-gguf")
- if err != nil {
- return "", err
- }
- defer f.Close()
-
- mod := llm.NewGGUFV3(m.Params.ByteOrder)
- if err := mod.Encode(f, kv, m.Tensors); err != nil {
- return "", err
- }
-
- slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
-
- return f.Name(), nil
+ return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
diff --git a/convert/mistral.go b/convert/mistral.go
index 70c92edd..f88de12b 100644
--- a/convert/mistral.go
+++ b/convert/mistral.go
@@ -132,7 +132,7 @@ func (m *MistralModel) LoadVocab() error {
 return nil
 }

-func (m *MistralModel) WriteGGUF() (string, error) {
+func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
 kv := llm.KV{
 "general.architecture": "llama",
 "general.name": m.Name,
@@ -158,16 +158,5 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 "tokenizer.ggml.unknown_token_id": uint32(0),
 }

- f, err := os.CreateTemp("", "ollama-gguf")
- if err != nil {
- return "", err
- }
- defer f.Close()
-
- mod := llm.NewGGUFV3(m.Params.ByteOrder)
- if err := mod.Encode(f, kv, m.Tensors); err != nil {
- return "", err
- }
-
- return f.Name(), nil
+ return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
 }
diff --git a/convert/mixtral.go b/convert/mixtral.go
new file mode 100644
index 00000000..940df55d
--- /dev/null
+++ b/convert/mixtral.go
@@ -0,0 +1,85 @@
+package convert
+
+import (
+ "io"
+ "regexp"
+
+ "github.com/ollama/ollama/llm"
+)
+
+type MixtralModel struct {
+ ModelData
+}
+
+func (m *MixtralModel) GetTensors() error {
+ t, err := m.Format.GetTensors(m.Path, m.Params)
+ if err != nil {
+ return err
+ }
+
+ m.Tensors = []llm.Tensor{}
+
+ pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
+ re, err := regexp.Compile(pattern)
+ if err != nil {
+ return err
+ }
+
+ for _, l := range t {
+ matches := re.FindAllStringSubmatch(l.Name, -1)
+ if len(matches) > 0 {
+ wt := l.WriterTo.(safetensorWriterTo)
+ wt.handler = mistralLayerHandler
+ l.WriterTo = wt
+ }
+ m.Tensors = append(m.Tensors, l)
+ }
+
+ return nil
+}
+
+func (m *MixtralModel) LoadVocab() error {
+ v, err := LoadSentencePieceTokens(m.Path, m.Params)
+ if err != nil {
+ return err
+ }
+ m.Vocab = v
+ return nil
+}
+
+func (m *MixtralModel) WriteGGUF(ws io.WriteSeeker) error {
+ kv := llm.KV{
+ "general.architecture": "llama",
+ "general.name": m.Name,
+ "llama.block_count": uint32(m.Params.HiddenLayers),
+ "llama.context_length": uint32(m.Params.ContextSize),
+ "llama.embedding_length": uint32(m.Params.HiddenSize),
+ "llama.feed_forward_length": uint32(m.Params.IntermediateSize),
+ "llama.attention.head_count": uint32(m.Params.AttentionHeads),
+ "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
+
+ "llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
+ "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
+
+ "llama.expert_count": uint32(m.Params.Experts),
+ "llama.expert_used_count": uint32(m.Params.ExpertsUsed),
+
+ "llama.vocab_size": uint32(len(m.Vocab.Tokens)),
+ "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
+
+ "general.file_type": uint32(1),
+ "tokenizer.ggml.model": "llama",
+
+ "tokenizer.ggml.tokens": m.Vocab.Tokens,
+ "tokenizer.ggml.scores": m.Vocab.Scores,
+ "tokenizer.ggml.token_type": m.Vocab.Types,
+
+ "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
+ "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
+ "tokenizer.ggml.unknown_token_id":
uint32(0), + "tokenizer.ggml.add_bos_token": true, + "tokenizer.ggml.add_eos_token": false, + } + + return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) +} diff --git a/convert/safetensors.go b/convert/safetensors.go index 468bc707..69424c4d 100644 --- a/convert/safetensors.go +++ b/convert/safetensors.go @@ -53,7 +53,7 @@ func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Ten var err error t, offset, err = m.readTensors(f, offset, params) if err != nil { - slog.Error("%v", err) + slog.Error(err.Error()) return nil, err } tensors = append(tensors, t...) @@ -93,7 +93,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) } slices.Sort(keys) - slog.Info("converting layers") var tensors []llm.Tensor @@ -105,7 +104,6 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) return nil, 0, err } - slog.Debug(fmt.Sprintf("metadata = %#v", data)) var size uint64 var kind uint32 switch len(data.Shape) { @@ -124,7 +122,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ggufName, err := m.GetLayerName(k) if err != nil { - slog.Error("%v", err) + slog.Error(err.Error()) return nil, 0, err } @@ -150,11 +148,13 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) padding: 8 + jsonSize, } - tensors = append(tensors, t) offset += size + tensors = append(tensors, t) } + slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors))) slog.Debug(fmt.Sprintf("offset = %d", offset)) + return tensors, offset, nil } @@ -185,15 +185,19 @@ func (m *SafetensorFormat) GetLayerName(n string) (string, error) { } tMap := map[string]string{ - "model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight", - "model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight", - "model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight", - "model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight", - "model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight", - "model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight", - "model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight", - "model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight", - "model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight", + "model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight", + "model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight", + "model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight", + "model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight", + "model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight", + "model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight", + "model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight", + "model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight", + "model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight", + "model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight", + "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight", + "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight", + "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight", } v, ok := directMap[n] @@ -286,6 +290,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M Format: m, }, }, 
nil + case "MixtralForCausalLM": + return &MixtralModel{ + ModelData{ + Name: name, + Path: dirPath, + Params: params, + Format: m, + }, + }, nil case "GemmaForCausalLM": return &GemmaModel{ ModelData{ diff --git a/convert/torch.go b/convert/torch.go index fd237505..92c58872 100644 --- a/convert/torch.go +++ b/convert/torch.go @@ -74,7 +74,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, ggufName, err := tf.GetLayerName(k.(string)) if err != nil { - slog.Error("%v", err) + slog.Error(err.Error()) return nil, err } slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName)) diff --git a/docs/api.md b/docs/api.md index aba605f7..2f52c55a 100644 --- a/docs/api.md +++ b/docs/api.md @@ -17,7 +17,7 @@ ### Model names -Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama2:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version. +Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version. ### Durations @@ -66,7 +66,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama2", + "model": "llama3", "prompt": "Why is the sky blue?" }' ``` @@ -77,7 +77,7 @@ A stream of JSON objects is returned: ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T08:52:19.385406455-07:00", "response": "The", "done": false @@ -90,16 +90,16 @@ The final response in the stream also includes additional data about the generat - `load_duration`: time spent in nanoseconds loading the model - `prompt_eval_count`: number of tokens in the prompt - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt -- `eval_count`: number of tokens the response +- `eval_count`: number of tokens in the response - `eval_duration`: time in nanoseconds spent generating the response - `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory - `response`: empty if the response was streamed, if not streamed, this will contain the full response -To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`. +To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration` * `10^9`. ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T19:22:45.499127Z", "response": "", "done": true, @@ -121,7 +121,7 @@ A response can be received in one reply when streaming is off. 
```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama2", + "model": "llama3", "prompt": "Why is the sky blue?", "stream": false }' @@ -133,7 +133,7 @@ If `stream` is set to `false`, the response will be a single JSON object: ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T19:22:45.499127Z", "response": "The sky is blue because it is the color of the sky.", "done": true, @@ -155,7 +155,7 @@ If `stream` is set to `false`, the response will be a single JSON object: ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama2", + "model": "llama3", "prompt": "What color is the sky at different times of the day? Respond using JSON", "format": "json", "stream": false @@ -166,7 +166,7 @@ curl http://localhost:11434/api/generate -d '{ ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-11-09T21:07:55.186497Z", "response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n", "done": true, @@ -289,7 +289,7 @@ If you want to set custom options for the model at runtime rather than in the Mo ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama2", + "model": "llama3", "prompt": "Why is the sky blue?", "stream": false, "options": { @@ -332,7 +332,7 @@ curl http://localhost:11434/api/generate -d '{ ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T19:22:45.499127Z", "response": "The sky is blue because it is the color of the sky.", "done": true, @@ -354,7 +354,7 @@ If an empty prompt is provided, the model will be loaded into memory. ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama2" + "model": "llama3" }' ``` @@ -364,7 +364,7 @@ A single JSON object is returned: ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-12-18T19:52:07.071755Z", "response": "", "done": true @@ -407,7 +407,7 @@ Send a chat message with a streaming response. ```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama2", + "model": "llama3", "messages": [ { "role": "user", @@ -423,7 +423,7 @@ A stream of JSON objects is returned: ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T08:52:19.385406455-07:00", "message": { "role": "assistant", @@ -438,7 +438,7 @@ Final response: ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T19:22:45.499127Z", "done": true, "total_duration": 4883583458, @@ -456,7 +456,7 @@ Final response: ```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama2", + "model": "llama3", "messages": [ { "role": "user", @@ -471,7 +471,7 @@ curl http://localhost:11434/api/chat -d '{ ```json { - "model": "registry.ollama.ai/library/llama2:latest", + "model": "registry.ollama.ai/library/llama3:latest", "created_at": "2023-12-12T14:13:43.416799Z", "message": { "role": "assistant", @@ -495,7 +495,7 @@ Send a chat message with a conversation history. 
You can use this same approach ```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama2", + "model": "llama3", "messages": [ { "role": "user", @@ -519,7 +519,7 @@ A stream of JSON objects is returned: ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T08:52:19.385406455-07:00", "message": { "role": "assistant", @@ -533,7 +533,7 @@ Final response: ```json { - "model": "llama2", + "model": "llama3", "created_at": "2023-08-04T19:22:45.499127Z", "done": true, "total_duration": 8113331500, @@ -591,7 +591,7 @@ curl http://localhost:11434/api/chat -d '{ ```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama2", + "model": "llama3", "messages": [ { "role": "user", @@ -609,7 +609,7 @@ curl http://localhost:11434/api/chat -d '{ ```json { - "model": "registry.ollama.ai/library/llama2:latest", + "model": "registry.ollama.ai/library/llama3:latest", "created_at": "2023-12-12T14:13:43.416799Z", "message": { "role": "assistant", @@ -651,7 +651,7 @@ Create a new model from a `Modelfile`. ```shell curl http://localhost:11434/api/create -d '{ "name": "mario", - "modelfile": "FROM llama2\nSYSTEM You are mario from Super Mario Bros." + "modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros." }' ``` @@ -758,7 +758,7 @@ A single JSON object will be returned. } }, { - "name": "llama2:latest", + "name": "llama3:latest", "modified_at": "2023-12-07T09:32:18.757212583-08:00", "size": 3825819519, "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e", @@ -792,7 +792,7 @@ Show information about a model including details, modelfile, template, parameter ```shell curl http://localhost:11434/api/show -d '{ - "name": "llama2" + "name": "llama3" }' ``` @@ -827,8 +827,8 @@ Copy a model. Creates a model with another name from an existing model. ```shell curl http://localhost:11434/api/copy -d '{ - "source": "llama2", - "destination": "llama2-backup" + "source": "llama3", + "destination": "llama3-backup" }' ``` @@ -854,7 +854,7 @@ Delete a model and its data. ```shell curl -X DELETE http://localhost:11434/api/delete -d '{ - "name": "llama2:13b" + "name": "llama3:13b" }' ``` @@ -882,7 +882,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where ```shell curl http://localhost:11434/api/pull -d '{ - "name": "llama2" + "name": "llama3" }' ``` diff --git a/docs/development.md b/docs/development.md index 76936c35..2f7b9ecf 100644 --- a/docs/development.md +++ b/docs/development.md @@ -51,7 +51,7 @@ Typically the build scripts will auto-detect CUDA, however, if your Linux distro or installation approach uses unusual paths, you can specify the location by specifying an environment variable `CUDA_LIB_DIR` to the location of the shared libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize -set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70") +a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70") Then generate dependencies: @@ -142,4 +142,4 @@ In addition to the common Windows development tools described above, install AMD - [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html) - [Strawberry Perl](https://strawberryperl.com/) -Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`). 
\ No newline at end of file
+Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
diff --git a/docs/faq.md b/docs/faq.md
index 6bd1b340..3fe3da89 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter:

 ```
 curl http://localhost:11434/api/generate -d '{
- "model": "llama2",
+ "model": "llama3",
 "prompt": "Why is the sky blue?",
 "options": {
 "num_ctx": 4096
@@ -88,9 +88,9 @@ On windows, Ollama inherits your user and system environment variables.

 3. Edit or create New variable(s) for your user account for `OLLAMA_HOST`, `OLLAMA_MODELS`, etc.

-4. Click OK/Apply to save
+4. Click OK/Apply to save

-5. Run `ollama` from a new terminal window
+5. Run `ollama` from a new terminal window

 ## How can I expose Ollama on my network?

@@ -140,7 +140,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e

 - macOS: `~/.ollama/models`
 - Linux: `/usr/share/ollama/.ollama/models`
-- Windows: `C:\Users\\.ollama\models`
+- Windows: `C:\Users\%username%\.ollama\models`

 ### How do I set them to a different location?

@@ -221,10 +221,20 @@ The `keep_alive` parameter can be set to:
 For example, to preload a model and leave it in memory use:

 ```shell
-curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}'
+curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": -1}'
 ```

 To unload the model and free up memory use:

 ```shell
-curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}'
+curl http://localhost:11434/api/generate -d '{"model": "llama3", "keep_alive": 0}'
 ```
+
+Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter mentioned above. Refer to the section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable.
+
+If you wish to override the `OLLAMA_KEEP_ALIVE` setting, use the `keep_alive` API parameter with the `/api/generate` or `/api/chat` API endpoints.
+
+## How do I manage the maximum number of requests the server can queue?
+
+If too many requests are sent to the server, it will respond with a 503 error
+indicating the server is overloaded. You can adjust how many requests may be
+queued by setting `OLLAMA_MAX_QUEUE`.
\ No newline at end of file
diff --git a/docs/import.md b/docs/import.md
index 672916b5..7041b74d 100644
--- a/docs/import.md
+++ b/docs/import.md
@@ -125,7 +125,7 @@ Publishing models is in early alpha. If you'd like to publish your model to shar

 1. Create [an account](https://ollama.com/signup)
 2. Copy your Ollama public key:
- - macOS: `cat ~/.ollama/id_ed25519.pub`
+ - macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
 - Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
 - Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
 3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
@@ -136,6 +136,8 @@ Next, copy your model to your username's namespace:
 ollama cp example <your username>/example
 ```

+> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
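As an illustration of that constraint, a client could pre-validate a candidate name before running `ollama cp`. A minimal Go sketch; the `validName` expression below is an assumption inferred from the note, not a rule exported by Ollama:

```go
package main

import (
	"fmt"
	"regexp"
)

// validName is a hypothetical pre-check mirroring the note above:
// lowercase letters, digits, '.', '-' and '_', with an optional
// namespace segment separated by '/'.
var validName = regexp.MustCompile(`^(?:[a-z0-9._-]+/)?[a-z0-9._-]+$`)

func main() {
	for _, name := range []string{"example/mario", "Example/Mario"} {
		fmt.Printf("%s -> %v\n", name, validName.MatchString(name))
	}
}
```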
+ Then push the model: ``` diff --git a/docs/linux.md b/docs/linux.md index 0ef4a30f..9e7e06fa 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -105,7 +105,7 @@ sudo chmod +x /usr/bin/ollama To view logs of Ollama running as a startup service, run: ```bash -journalctl -u ollama +journalctl -e -u ollama ``` ## Uninstall diff --git a/docs/modelfile.md b/docs/modelfile.md index 24002bde..21ee1826 100644 --- a/docs/modelfile.md +++ b/docs/modelfile.md @@ -10,7 +10,7 @@ A model file is the blueprint to create and share models with Ollama. - [Examples](#examples) - [Instructions](#instructions) - [FROM (Required)](#from-required) - - [Build from llama2](#build-from-llama2) + - [Build from llama3](#build-from-llama3) - [Build from a bin file](#build-from-a-bin-file) - [PARAMETER](#parameter) - [Valid Parameters and Values](#valid-parameters-and-values) @@ -48,7 +48,7 @@ INSTRUCTION arguments An example of a `Modelfile` creating a mario blueprint: ```modelfile -FROM llama2 +FROM llama3 # sets the temperature to 1 [higher is more creative, lower is more coherent] PARAMETER temperature 1 # sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token @@ -67,33 +67,25 @@ To use this: More examples are available in the [examples directory](../examples). -### `Modelfile`s in [ollama.com/library][1] - -There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]: - -- Option 1: view a details page from a model's tags page: - 1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags) - 2. Click on a tag (e.g. https://ollama.com/library/llama2:13b) - 3. Scroll down to "Layers" - - Note: if the [`FROM` instruction](#from-required) is not present, - it means the model was created from a local file -- Option 2: use `ollama show` to print the `Modelfile` for any local models like so: +To view the Modelfile of a given model, use the `ollama show --modelfile` command. ```bash - > ollama show --modelfile llama2:13b + > ollama show --modelfile llama3 # Modelfile generated by "ollama show" # To build a new Modelfile based on this one, replace the FROM line with: - # FROM llama2:13b + # FROM llama3:latest + FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29 + TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|> - FROM /root/.ollama/models/blobs/sha256:123abc - TEMPLATE """[INST] {{ if .System }}<>{{ .System }}<> + {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> - {{ end }}{{ .Prompt }} [/INST] """ - SYSTEM """""" - PARAMETER stop [INST] - PARAMETER stop [/INST] - PARAMETER stop <> - PARAMETER stop <> + {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> + + {{ .Response }}<|eot_id|>""" + PARAMETER stop "<|start_header_id|>" + PARAMETER stop "<|end_header_id|>" + PARAMETER stop "<|eot_id|>" + PARAMETER stop "<|reserved_special_token" ``` ## Instructions @@ -106,10 +98,10 @@ The `FROM` instruction defines the base model to use when creating a model. 
FROM <model name>:<tag>
 ```

-#### Build from llama2
+#### Build from llama3

 ```modelfile
-FROM llama2
+FROM llama3
 ```

 A list of available base models:
diff --git a/docs/openai.md b/docs/openai.md
index b4dc1f21..557b5846 100644
--- a/docs/openai.md
+++ b/docs/openai.md
@@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create(
 'content': 'Say this is a test',
 }
 ],
- model='llama2',
+ model='llama3',
 )
 ```
@@ -43,7 +43,7 @@ const openai = new OpenAI({

 const chatCompletion = await openai.chat.completions.create({
 messages: [{ role: 'user', content: 'Say this is a test' }],
- model: 'llama2',
+ model: 'llama3',
 })
 ```
@@ -53,7 +53,7 @@ const chatCompletion = await openai.chat.completions.create({
 curl http://localhost:11434/v1/chat/completions \
 -H "Content-Type: application/json" \
 -d '{
- "model": "llama2",
+ "model": "llama3",
 "messages": [
 {
 "role": "system",
@@ -113,7 +113,7 @@ curl http://localhost:11434/v1/chat/completions \
 Before using a model, pull it locally with `ollama pull`:

 ```shell
-ollama pull llama2
+ollama pull llama3
 ```

 ### Default model names

@@ -121,7 +121,7 @@ ollama pull llama2
 For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:

 ```
-ollama cp llama2 gpt-3.5-turbo
+ollama cp llama3 gpt-3.5-turbo
 ```

 Afterwards, this new model name can be specified in the `model` field:
diff --git a/docs/tutorials/langchainjs.md b/docs/tutorials/langchainjs.md
index 7cd4012f..63b34aa6 100644
--- a/docs/tutorials/langchainjs.md
+++ b/docs/tutorials/langchainjs.md
@@ -15,7 +15,7 @@ import { Ollama } from "langchain/llms/ollama";

 const ollama = new Ollama({
 baseUrl: "http://localhost:11434",
- model: "llama2",
+ model: "llama3",
 });

 const answer = await ollama.invoke(`why is the sky blue?`);
@@ -23,10 +23,10 @@ const answer = await ollama.invoke(`why is the sky blue?`);

 console.log(answer);
 ```

-That will get us the same thing as if we ran `ollama run llama2 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
+That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.

 ```bash
-npm install cheerio
+npm install cheerio
 ```

 ```javascript
diff --git a/docs/tutorials/langchainpy.md b/docs/tutorials/langchainpy.md
index f6ee4fa3..9a1bca0d 100644
--- a/docs/tutorials/langchainpy.md
+++ b/docs/tutorials/langchainpy.md
@@ -12,15 +12,17 @@ So let's figure out how we can use **LangChain** with Ollama to ask our question

 Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**.
First, we need to install the **LangChain** package:

-`pip install langchain`
+`pip install langchain_community`

 Then we can create a model and ask the question:

 ```python
-from langchain.llms import Ollama
-ollama = Ollama(base_url='http://localhost:11434',
-model="llama2")
-print(ollama("why is the sky blue"))
+from langchain_community.llms import Ollama
+ollama = Ollama(
+ base_url='http://localhost:11434',
+ model="llama3"
+)
+print(ollama.invoke("why is the sky blue"))
 ```

 Notice that we are defining the model and the base URL for Ollama.
diff --git a/docs/tutorials/nvidia-jetson.md b/docs/tutorials/nvidia-jetson.md
index 2d3adb98..bb77c486 100644
--- a/docs/tutorials/nvidia-jetson.md
+++ b/docs/tutorials/nvidia-jetson.md
@@ -1,38 +1,15 @@
 # Running Ollama on NVIDIA Jetson Devices

-With some minor configuration, Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/). The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack).
+Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions.

-NVIDIA Jetson devices are Linux-based embedded AI computers that are purpose-built for AI applications.
-
-Jetsons have an integrated GPU that is wired directly to the memory controller of the machine. For this reason, the `nvidia-smi` command is unrecognized, and Ollama proceeds to operate in "CPU only"
-mode. This can be verified by using a monitoring tool like jtop.
-
-In order to address this, we simply pass the path to the Jetson's pre-installed CUDA libraries into `ollama serve` (while in a tmux session). We then hardcode the num_gpu parameters into a cloned
-version of our target model.
-
-Prerequisites:
-
-- curl
-- tmux
-
-Here are the steps:
+The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.

 - Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
-- Stop the Ollama service: `sudo systemctl stop ollama`
-- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
-'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`
 - Pull the model you want to use (e.g. mistral): `ollama pull mistral`
-- Create a new Modelfile specifically for enabling GPU support on the Jetson: `touch ModelfileMistralJetson`
-- In the ModelfileMistralJetson file, specify the FROM model and the num_gpu PARAMETER as shown below:
-
-```
-FROM mistral
-PARAMETER num_gpu 999
-```
-
-- Create a new model from your Modelfile: `ollama create mistral-jetson -f ./ModelfileMistralJetson`
-- Run the new model: `ollama run mistral-jetson`
-
-If you run a monitoring tool like jtop you should now see that Ollama is using the Jetson's integrated GPU.
+- Start an interactive session: `ollama run mistral`

 And that's it!
+
+# Running Ollama in Docker
+
+When running GPU-accelerated applications in Docker, it is highly recommended to use the [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).
\ No newline at end of file
diff --git a/docs/windows.md b/docs/windows.md
index 49d579c9..242b810a 100644
--- a/docs/windows.md
+++ b/docs/windows.md
@@ -1,47 +1,61 @@
-# Ollama Windows Preview
-
-Welcome to the Ollama Windows preview.
-
-No more WSL required!
-
-Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
-After installing Ollama Windows Preview, Ollama will run in the background and
-the `ollama` command line is available in `cmd`, `powershell` or your favorite
-terminal application. As usual the Ollama [api](./api.md) will be served on
-`http://localhost:11434`.
-
-As this is a preview release, you should expect a few bugs here and there. If
-you run into a problem you can reach out on
-[Discord](https://discord.gg/ollama), or file an
-[issue](https://github.com/ollama/ollama/issues).
-Logs will often be helpful in dianosing the problem (see
-[Troubleshooting](#troubleshooting) below)
-
-## System Requirements
-
-* Windows 10 or newer, Home or Pro
-* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
-* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
-
-## API Access
-
-Here's a quick example showing API access from `powershell`
-```powershell
-(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
-```
-
-## Troubleshooting
-
-While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
-a "view logs" menu item to the app, and increses logging for the GUI app and
-server.
-
-Ollama on Windows stores files in a few different locations. You can view them in
-the explorer window by hitting `<cmd>+R` and type in:
-- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
- - *app.log* contains logs from the GUI application
- - *server.log* contains the server logs
- - *upgrade.log* contains log output for upgrades
-- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
-- `explorer %HOMEPATH%\.ollama` contains models and configuration
-- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
+# Ollama Windows Preview
+
+Welcome to the Ollama Windows preview.
+
+No more WSL required!
+
+Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
+After installing Ollama Windows Preview, Ollama will run in the background and
+the `ollama` command line is available in `cmd`, `powershell` or your favorite
+terminal application. As usual the Ollama [api](./api.md) will be served on
+`http://localhost:11434`.
+
+As this is a preview release, you should expect a few bugs here and there. If
+you run into a problem you can reach out on
+[Discord](https://discord.gg/ollama), or file an
+[issue](https://github.com/ollama/ollama/issues).
+Logs will often be helpful in diagnosing the problem (see
+[Troubleshooting](#troubleshooting) below)
+
+## System Requirements
+
+* Windows 10 or newer, Home or Pro
+* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
+* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
+
+## API Access
+
+Here's a quick example showing API access from `powershell`
+```powershell
+(Invoke-WebRequest -method POST -Body '{"model":"llama3", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
+```
+
+## Troubleshooting
+
+While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
+a "view logs" menu item to the app, and increases logging for the GUI app and
+server.
+
+Ollama on Windows stores files in a few different locations.
You can view them in
+the explorer window by hitting `<cmd>+R` and typing in:
+- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
+ - *app.log* contains logs from the GUI application
+ - *server.log* contains the server logs
+ - *upgrade.log* contains log output for upgrades
+- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
+- `explorer %HOMEPATH%\.ollama` contains models and configuration
+- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories
+
+
+## Standalone CLI
+
+The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
+installer. It installs in your account without requiring Administrator rights.
+We update Ollama regularly to support the latest models, and this installer will
+help you keep up to date.
+
+If you'd like to install or integrate Ollama as a service, a standalone
+`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
+and GPU library dependencies for Nvidia and AMD. This allows for embedding
+Ollama in existing applications, or running it as a system service via `ollama
+serve` with tools such as [NSSM](https://nssm.cc/).
diff --git a/examples/bash-comparemodels/README.md b/examples/bash-comparemodels/README.md
index 91499255..65e66f1e 100644
--- a/examples/bash-comparemodels/README.md
+++ b/examples/bash-comparemodels/README.md
@@ -2,7 +2,7 @@
 When calling `ollama`, you can pass it a file to run all the prompts in the file, one after the other:

-`ollama run llama2 < sourcequestions.txt`
+`ollama run llama3 < sourcequestions.txt`

 This concept is used in the following example.
diff --git a/examples/flyio/.gitignore b/examples/flyio/.gitignore
new file mode 100644
index 00000000..0501d092
--- /dev/null
+++ b/examples/flyio/.gitignore
@@ -0,0 +1 @@
+fly.toml
diff --git a/examples/flyio/README.md b/examples/flyio/README.md
new file mode 100644
index 00000000..09b90aad
--- /dev/null
+++ b/examples/flyio/README.md
@@ -0,0 +1,67 @@
+# Deploy Ollama to Fly.io
+
+> Note: this example exposes a public endpoint and does not configure authentication. Use with care.
+
+## Prerequisites
+
+- Ollama: https://ollama.com/download
+- Fly.io account. Sign up for a free account: https://fly.io/app/sign-up
+
+## Steps
+
+1. Log in to Fly.io
+
+ ```bash
+ fly auth login
+ ```
+
+1. Create a new Fly app
+
+ ```bash
+ fly launch --name <name> --image ollama/ollama --internal-port 11434 --vm-size shared-cpu-8x --now
+ ```
+
+1. Pull and run `orca-mini:3b`
+
+ ```bash
+ OLLAMA_HOST=https://<name>.fly.dev ollama run orca-mini:3b
+ ```
+
+`shared-cpu-8x` is a free-tier eligible machine type. For better performance, switch to a `performance` or `dedicated` machine type or attach a GPU for hardware acceleration (see below).
+
+## (Optional) Persistent Volume
+
+By default Fly Machines use ephemeral storage, which is problematic if you want to use the same model across restarts without pulling it again. Create and attach a persistent volume to store the downloaded models:
+
+1. Create the Fly Volume
+
+ ```bash
+ fly volume create ollama
+ ```
+
+1. Update `fly.toml` and add `[mounts]`
+
+ ```toml
+ [mounts]
+ source = "ollama"
+ destination = "/mnt/ollama/models"
+ ```
+
+1. Update `fly.toml` and add `[env]`
+
+ ```toml
+ [env]
+ OLLAMA_MODELS = "/mnt/ollama/models"
+ ```
+
+1. Deploy your app
+
+ ```bash
+ fly deploy
+ ```
+
+## (Optional) Hardware Acceleration
+
+Fly.io GPU access is currently waitlisted. Sign up for the waitlist: https://fly.io/gpu
+
+Once you've been accepted, create the app with the additional flag `--vm-gpu-kind a100-pcie-40gb` or `--vm-gpu-kind a100-pcie-80gb`.
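Since both the CLI and the Go client resolve the server address from `OLLAMA_HOST`, the deployed app can also be reached programmatically. A minimal sketch using this repository's `api` package; the Fly app host and model are placeholders, and error handling is kept to a minimum:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// api.ClientFromEnvironment honors OLLAMA_HOST, so running this with
	// OLLAMA_HOST=https://<name>.fly.dev points it at the Fly.io app.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "orca-mini:3b",
		Prompt: "Why is the sky blue?",
	}

	// stream the response tokens to stdout as they arrive
	err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```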
diff --git a/examples/go-chat/main.go b/examples/go-chat/main.go
index 83aaad3d..5266f03e 100644
--- a/examples/go-chat/main.go
+++ b/examples/go-chat/main.go
@@ -35,7 +35,7 @@ func main() {
 ctx := context.Background()

 req := &api.ChatRequest{
- Model: "llama2",
+ Model: "llama3",
 Messages: messages,
 }

diff --git a/examples/go-http-generate/main.go b/examples/go-http-generate/main.go
index f4ca32f4..e5b64348 100644
--- a/examples/go-http-generate/main.go
+++ b/examples/go-http-generate/main.go
@@ -19,7 +19,7 @@ func main() {
 }
 defer resp.Body.Close()
-
+
 responseData, err := io.ReadAll(resp.Body)
 if err != nil {
 log.Fatal(err)
diff --git a/examples/kubernetes/README.md b/examples/kubernetes/README.md
index c522ba76..2e2444c7 100644
--- a/examples/kubernetes/README.md
+++ b/examples/kubernetes/README.md
@@ -7,12 +7,24 @@

 ## Steps

-1. Create the Ollama namespace, daemon set, and service
+1. Create the Ollama namespace, deployment, and service

 ```bash
 kubectl apply -f cpu.yaml
 ```

+## (Optional) Hardware Acceleration
+
+Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin), which is deployed in Kubernetes in the form of a daemonset. Follow the link for more details.
+
+Once configured, create a GPU-enabled Ollama deployment.
+
+```bash
+kubectl apply -f gpu.yaml
+```
+
+## Test
+
 1. Port forward the Ollama service to connect and use it locally

 ```bash
@@ -23,14 +35,4 @@

 ```bash
 ollama run orca-mini:3b
- ```
-
-## (Optional) Hardware Acceleration
-
-Hardware acceleration in Kubernetes requires NVIDIA's [`k8s-device-plugin`](https://github.com/NVIDIA/k8s-device-plugin). Follow the link for more details.
-
-Once configured, create a GPU enabled Ollama deployment.
-
-```bash
-kubectl apply -f gpu.yaml
-```
+ ```
\ No newline at end of file
diff --git a/examples/langchain-python-rag-document/main.py b/examples/langchain-python-rag-document/main.py
index b9f98c4e..3ed9499f 100644
--- a/examples/langchain-python-rag-document/main.py
+++ b/examples/langchain-python-rag-document/main.py
@@ -40,9 +40,9 @@ while True:
 continue

 # Prompt
- template = """Use the following pieces of context to answer the question at the end.
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
- Use three sentences maximum and keep the answer as concise as possible.
+ template = """Use the following pieces of context to answer the question at the end.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
+ Use three sentences maximum and keep the answer as concise as possible.
{context} Question: {question} Helpful Answer:""" @@ -51,11 +51,11 @@ while True: template=template, ) - llm = Ollama(model="llama2:13b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])) + llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])) qa_chain = RetrievalQA.from_chain_type( llm, retriever=vectorstore.as_retriever(), chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}, ) - result = qa_chain({"query": query}) \ No newline at end of file + result = qa_chain({"query": query}) diff --git a/examples/langchain-python-rag-websummary/main.py b/examples/langchain-python-rag-websummary/main.py index cd2ef47f..d1b05ba8 100644 --- a/examples/langchain-python-rag-websummary/main.py +++ b/examples/langchain-python-rag-websummary/main.py @@ -1,12 +1,12 @@ -from langchain.llms import Ollama -from langchain.document_loaders import WebBaseLoader +from langchain_community.llms import Ollama +from langchain_community.document_loaders import WebBaseLoader from langchain.chains.summarize import load_summarize_chain loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally") docs = loader.load() -llm = Ollama(model="llama2") +llm = Ollama(model="llama3") chain = load_summarize_chain(llm, chain_type="stuff") -result = chain.run(docs) +result = chain.invoke(docs) print(result) diff --git a/examples/langchain-python-simple/README.md b/examples/langchain-python-simple/README.md index 3f401ca8..d4102dec 100644 --- a/examples/langchain-python-simple/README.md +++ b/examples/langchain-python-simple/README.md @@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama. ## Running the Example -1. Ensure you have the `llama2` model installed: +1. Ensure you have the `llama3` model installed: ```bash - ollama pull llama2 + ollama pull llama3 ``` 2. Install the Python Requirements. @@ -21,4 +21,3 @@ This example is a basic "hello world" of using LangChain with Ollama. ```bash python main.py ``` - \ No newline at end of file diff --git a/examples/langchain-python-simple/main.py b/examples/langchain-python-simple/main.py index da696e00..7cb65286 100644 --- a/examples/langchain-python-simple/main.py +++ b/examples/langchain-python-simple/main.py @@ -1,6 +1,6 @@ from langchain.llms import Ollama input = input("What is your question?") -llm = Ollama(model="llama2") +llm = Ollama(model="llama3") res = llm.predict(input) print (res) diff --git a/examples/modelfile-mario/Modelfile b/examples/modelfile-mario/Modelfile index 35c787fc..33d5952b 100644 --- a/examples/modelfile-mario/Modelfile +++ b/examples/modelfile-mario/Modelfile @@ -1,4 +1,4 @@ -FROM llama2 +FROM llama3 PARAMETER temperature 1 SYSTEM """ You are Mario from super mario bros, acting as an assistant. diff --git a/examples/modelfile-mario/readme.md b/examples/modelfile-mario/readme.md index 0d72dddc..e4f0d417 100644 --- a/examples/modelfile-mario/readme.md +++ b/examples/modelfile-mario/readme.md @@ -2,12 +2,12 @@ # Example character: Mario -This example shows how to create a basic character using Llama2 as the base model. +This example shows how to create a basic character using Llama3 as the base model. To run this example: 1. Download the Modelfile -2. `ollama pull llama2` to get the base model used in the model file. +2. `ollama pull llama3` to get the base model used in the model file. 3. `ollama create NAME -f ./Modelfile` 4. `ollama run NAME` @@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?" 
What the model file looks like: ``` -FROM llama2 +FROM llama3 PARAMETER temperature 1 SYSTEM """ You are Mario from Super Mario Bros, acting as an assistant. diff --git a/examples/python-json-datagenerator/predefinedschema.py b/examples/python-json-datagenerator/predefinedschema.py index abc399c4..1fd54892 100644 --- a/examples/python-json-datagenerator/predefinedschema.py +++ b/examples/python-json-datagenerator/predefinedschema.py @@ -2,16 +2,16 @@ import requests import json import random -model = "llama2" +model = "llama3" template = { - "firstName": "", - "lastName": "", + "firstName": "", + "lastName": "", "address": { - "street": "", - "city": "", - "state": "", + "street": "", + "city": "", + "state": "", "zipCode": "" - }, + }, "phoneNumber": "" } diff --git a/examples/python-json-datagenerator/randomaddresses.py b/examples/python-json-datagenerator/randomaddresses.py index 5f27448f..72b1fefb 100644 --- a/examples/python-json-datagenerator/randomaddresses.py +++ b/examples/python-json-datagenerator/randomaddresses.py @@ -12,7 +12,7 @@ countries = [ "France", ] country = random.choice(countries) -model = "llama2" +model = "llama3" prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters." diff --git a/examples/python-json-datagenerator/readme.md b/examples/python-json-datagenerator/readme.md index 369fb2a5..88357044 100644 --- a/examples/python-json-datagenerator/readme.md +++ b/examples/python-json-datagenerator/readme.md @@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran ## Running the Example -1. Ensure you have the `llama2` model installed: +1. Ensure you have the `llama3` model installed: ```bash - ollama pull llama2 + ollama pull llama3 ``` 2. Install the Python Requirements. diff --git a/examples/python-simplechat/client.py b/examples/python-simplechat/client.py index 768a2289..9ae99fb7 100644 --- a/examples/python-simplechat/client.py +++ b/examples/python-simplechat/client.py @@ -2,7 +2,7 @@ import json import requests # NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve` -model = "llama2" # TODO: update this for whatever model you wish to use +model = "llama3" # TODO: update this for whatever model you wish to use def chat(messages): diff --git a/examples/python-simplechat/readme.md b/examples/python-simplechat/readme.md index 204a8159..dd2576bc 100644 --- a/examples/python-simplechat/readme.md +++ b/examples/python-simplechat/readme.md @@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam ## Running the Example -1. Ensure you have the `llama2` model installed: +1. Ensure you have the `llama3` model installed: ```bash - ollama pull llama2 + ollama pull llama3 ``` 2. Install the Python Requirements. diff --git a/examples/typescript-mentors/README.md b/examples/typescript-mentors/README.md index c3ce9c82..d3611a5e 100644 --- a/examples/typescript-mentors/README.md +++ b/examples/typescript-mentors/README.md @@ -4,10 +4,10 @@ This example demonstrates how one would create a set of 'mentors' you can have a ## Usage -1. Add llama2 to have the mentors ask your questions: +1. Add llama3 to have the mentors ask your questions: ```bash - ollama pull llama2 + ollama pull llama3 ``` 2. 
Install prerequisites: diff --git a/examples/typescript-mentors/character-generator.ts b/examples/typescript-mentors/character-generator.ts index 886eec67..dc5d2f5e 100644 --- a/examples/typescript-mentors/character-generator.ts +++ b/examples/typescript-mentors/character-generator.ts @@ -15,7 +15,7 @@ async function characterGenerator() { ollama.setModel("stablebeluga2:70b-q4_K_M"); const bio = await ollama.generate(`create a bio of ${character} in a single long paragraph. Instead of saying '${character} is...' or '${character} was...' use language like 'You are...' or 'You were...'. Then create a paragraph describing the speaking mannerisms and style of ${character}. Don't include anything about how ${character} looked or what they sounded like, just focus on the words they said. Instead of saying '${character} would say...' use language like 'You should say...'. If you use quotes, always use single quotes instead of double quotes. If there are any specific words or phrases you used a lot, show how you used them. `); - const thecontents = `FROM llama2\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`; + const thecontents = `FROM llama3\nSYSTEM """\n${bio.response.replace(/(\r\n|\n|\r)/gm, " ").replace('would', 'should')} All answers to questions should be related back to what you are most known for.\n"""`; fs.writeFile(path.join(directory, 'Modelfile'), thecontents, (err: any) => { if (err) throw err; @@ -23,4 +23,4 @@ async function characterGenerator() { }); } -characterGenerator(); \ No newline at end of file +characterGenerator(); diff --git a/examples/typescript-simplechat/client.ts b/examples/typescript-simplechat/client.ts index 3e571ab6..a1e0eea3 100644 --- a/examples/typescript-simplechat/client.ts +++ b/examples/typescript-simplechat/client.ts @@ -1,6 +1,6 @@ import * as readline from "readline"; -const model = "llama2"; +const model = "llama3"; type Message = { role: "assistant" | "user" | "system"; content: string; @@ -74,4 +74,4 @@ async function main() { } -main(); \ No newline at end of file +main(); diff --git a/format/bytes.go b/format/bytes.go index f4bcc8c5..13d8575e 100644 --- a/format/bytes.go +++ b/format/bytes.go @@ -15,6 +15,7 @@ const ( KibiByte = Byte * 1024 MebiByte = KibiByte * 1024 + GibiByte = MebiByte * 1024 ) func HumanBytes(b int64) string { @@ -52,6 +53,8 @@ func HumanBytes(b int64) string { func HumanBytes2(b uint64) string { switch { + case b >= GibiByte: + return fmt.Sprintf("%.1f GiB", float64(b)/GibiByte) case b >= MebiByte: return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte) case b >= KibiByte: diff --git a/format/format.go b/format/format.go index 8fd2defa..31059578 100644 --- a/format/format.go +++ b/format/format.go @@ -13,12 +13,20 @@ const ( func HumanNumber(b uint64) string { switch { - case b > Billion: - return fmt.Sprintf("%.0fB", math.Round(float64(b)/Billion)) - case b > Million: - return fmt.Sprintf("%.0fM", math.Round(float64(b)/Million)) - case b > Thousand: - return fmt.Sprintf("%.0fK", math.Round(float64(b)/Thousand)) + case b >= Billion: + number := float64(b) / Billion + if number == math.Floor(number) { + return fmt.Sprintf("%.0fB", number) // no decimals if whole number + } + return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number + case b >= Million: + number := float64(b) / Million + if number == math.Floor(number) { + return fmt.Sprintf("%.0fM", number) // no decimals if whole number + } 
+ return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number + case b >= Thousand: + return fmt.Sprintf("%.0fK", float64(b)/Thousand) default: return fmt.Sprintf("%d", b) } diff --git a/format/format_test.go b/format/format_test.go new file mode 100644 index 00000000..1d73c80b --- /dev/null +++ b/format/format_test.go @@ -0,0 +1,34 @@ +package format + +import ( + "testing" +) + +func TestHumanNumber(t *testing.T) { + + type testCase struct { + input uint64 + expected string + } + + testCases := []testCase{ + {0, "0"}, + {1000000, "1M"}, + {125000000, "125M"}, + {500500000, "500.50M"}, + {500550000, "500.55M"}, + {1000000000, "1B"}, + {2800000000, "2.8B"}, + {2850000000, "2.9B"}, + {1000000000000, "1000B"}, + } + + for _, tc := range testCases { + t.Run(tc.expected, func(t *testing.T) { + result := HumanNumber(tc.input) + if result != tc.expected { + t.Errorf("Expected %s, got %s", tc.expected, result) + } + }) + } +} diff --git a/gpu/amd_common.go b/gpu/amd_common.go index cf3348a8..27a81e3f 100644 --- a/gpu/amd_common.go +++ b/gpu/amd_common.go @@ -7,7 +7,7 @@ import ( "log/slog" "os" "path/filepath" - "strconv" + "runtime" "strings" ) @@ -35,22 +35,66 @@ func GetSupportedGFX(libDir string) ([]string, error) { return ret, nil } -func amdSetVisibleDevices(ids []int, skip map[int]interface{}) { - // Set the visible devices if not already set - // TODO - does sort order matter? - devices := []string{} - for i := range ids { - if _, skipped := skip[i]; skipped { +func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "rocm" { + // TODO shouldn't happen if things are wired correctly... + slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library) continue } - devices = append(devices, strconv.Itoa(i)) + ids = append(ids, info.ID) + } + return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",") +} + +func commonAMDValidateLibDir() (string, error) { + // We try to favor system paths first, so that we can wire up the subprocess to use + // the system version. 
Only use our bundled version if the system version doesn't work + // This gives users more recovery options if versions have subtle problems at runtime + + // Prefer explicit HIP env var + hipPath := os.Getenv("HIP_PATH") + if hipPath != "" { + hipLibDir := filepath.Join(hipPath, "bin") + if rocmLibUsable(hipLibDir) { + slog.Debug("detected ROCM via HIP_PATH=" + hipPath) + return hipLibDir, nil + } } - val := strings.Join(devices, ",") - err := os.Setenv("HIP_VISIBLE_DEVICES", val) - if err != nil { - slog.Warn(fmt.Sprintf("failed to set env: %s", err)) - } else { - slog.Info("Setting HIP_VISIBLE_DEVICES=" + val) + // Scan the LD_LIBRARY_PATH or PATH + pathEnv := "LD_LIBRARY_PATH" + if runtime.GOOS == "windows" { + pathEnv = "PATH" } + + paths := os.Getenv(pathEnv) + for _, path := range filepath.SplitList(paths) { + d, err := filepath.Abs(path) + if err != nil { + continue + } + if rocmLibUsable(d) { + return d, nil + } + } + + // Well known location(s) + for _, path := range RocmStandardLocations { + if rocmLibUsable(path) { + return path, nil + } + } + + // Installer payload location if we're running the installed binary + exe, err := os.Executable() + if err == nil { + rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm") + if rocmLibUsable(rocmTargetDir) { + slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) + return rocmTargetDir, nil + } + } + return "", fmt.Errorf("no suitable rocm found, falling back to CPU") } diff --git a/gpu/amd_hip_windows.go b/gpu/amd_hip_windows.go index 14a6c7d6..4e216132 100644 --- a/gpu/amd_hip_windows.go +++ b/gpu/amd_hip_windows.go @@ -69,7 +69,7 @@ func NewHipLib() (*HipLib, error) { func (hl *HipLib) Release() { err := windows.FreeLibrary(hl.dll) if err != nil { - slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err)) + slog.Warn("failed to unload amdhip64.dll", "error", err) } hl.dll = 0 } @@ -98,7 +98,7 @@ func (hl *HipLib) HipGetDeviceCount() int { return 0 } if status != hipSuccess { - slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err)) + slog.Warn("failed call to hipGetDeviceCount", "status", status, "error", err) } return count } diff --git a/gpu/amd_linux.go b/gpu/amd_linux.go index 529fb8db..9f9f8e74 100644 --- a/gpu/amd_linux.go +++ b/gpu/amd_linux.go @@ -11,6 +11,8 @@ import ( "slices" "strconv" "strings" + + "github.com/ollama/ollama/format" ) // Discovery logic for AMD/ROCm GPUs @@ -23,26 +25,20 @@ const ( // Prefix with the node dir GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line GPUUsedMemoryFileGlob = "mem_banks/*/used_memory" - RocmStandardLocation = "/opt/rocm/lib" - - // TODO find a better way to detect iGPU instead of minimum memory - IGPUMemLimit = 1024 * 1024 * 1024 // 512G is what they typically report, so anything less than 1G must be iGPU ) var ( // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here... + ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here...
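+ // Well-known system install locations scanned by commonAMDValidateLibDir above, after the HIP_PATH and library-path searches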
+ RocmStandardLocations = []string{"/opt/rocm/lib", "/usr/lib64"} ) // Gather GPU information from the amdgpu driver if any supported GPUs are detected -// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices -// and the user hasn't already set this variable -func AMDGetGPUInfo(resp *GpuInfo) { - // TODO - DRY this out with windows +func AMDGetGPUInfo() []GpuInfo { + resp := []GpuInfo{} if !AMDDetected() { - return + return resp } - skip := map[int]interface{}{} // Opportunistic logging of driver version to aid in troubleshooting ver, err := AMDDriverVersion() @@ -50,160 +46,117 @@ if err == nil { slog.Info("AMD Driver: " + ver) } else { // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU - slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err)) + slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err) } - // If the user has specified exactly which GPUs to use, look up their memory - visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES") - if visibleDevices != "" { - ids := []int{} - for _, idStr := range strings.Split(visibleDevices, ",") { - id, err := strconv.Atoi(idStr) - if err != nil { - slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err)) - } else { - ids = append(ids, id) - } - } - amdProcMemLookup(resp, nil, ids) - return + // Determine if the user has already pre-selected which GPUs to look at, then ignore the others + var visibleDevices []string + hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only + rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID + gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index + switch { + // TODO is this priority order right?
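+ // e.g. (hypothetical) HIP_VISIBLE_DEVICES="0,2" restricts discovery to the first and third GPUs; if more than one of these variables is set, the first match below wins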
+ case hipVD != "": + visibleDevices = strings.Split(hipVD, ",") + case rocrVD != "": + visibleDevices = strings.Split(rocrVD, ",") + // TODO - since we don't yet support UUIDs, consider detecting and reporting here + // all our test systems show GPU-XX indicating UUID is not supported + case gpuDO != "": + visibleDevices = strings.Split(gpuDO, ",") } - // Gather GFX version information from all detected cards - gfx := AMDGFXVersions() - verStrings := []string{} - for i, v := range gfx { - verStrings = append(verStrings, v.ToGFXString()) - if v.Major == 0 { - // Silently skip CPUs - skip[i] = struct{}{} - continue - } - if v.Major < 9 { - // TODO consider this a build-time setting if we can support 8xx family GPUs - slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString())) - skip[i] = struct{}{} - } - } - slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings)) - - // Abort if all GPUs are skipped - if len(skip) >= len(gfx) { - slog.Info("all detected amdgpus are skipped, falling back to CPU") - return - } - - // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib - libDir, err := AMDValidateLibDir() - if err != nil { - slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err)) - return - } - - updateLibPath(libDir) - gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION") - if gfxOverride == "" { - supported, err := GetSupportedGFX(libDir) + var supported []string + libDir := "" + + // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract + // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU) + matches, _ := filepath.Glob(GPUPropertiesFileGlob) + cpuCount := 0 + for _, match := range matches { + slog.Debug("evaluating amdgpu node " + match) + fp, err := os.Open(match) if err != nil { - slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) - return - } - slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported)) - - for i, v := range gfx { - if !slices.Contains[[]string, string](supported, v.ToGFXString()) { - slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported)) - // TODO - consider discrete markdown just for ROCM troubleshooting? 
- slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") - skip[i] = struct{}{} - } else { - slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString())) - } - } - } else { - slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) - } - - if len(skip) >= len(gfx) { - slog.Info("all detected amdgpus are skipped, falling back to CPU") - return - } - - ids := make([]int, len(gfx)) - i := 0 - for k := range gfx { - ids[i] = k - i++ - } - amdProcMemLookup(resp, skip, ids) - if resp.memInfo.DeviceCount == 0 { - return - } - if len(skip) > 0 { - amdSetVisibleDevices(ids, skip) - } -} - -func updateLibPath(libDir string) { - ldPaths := []string{} - if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok { - ldPaths = strings.Split(val, ":") - } - for _, d := range ldPaths { - if d == libDir { - return - } - } - val := strings.Join(append(ldPaths, libDir), ":") - slog.Debug("updated lib path", "LD_LIBRARY_PATH", val) - os.Setenv("LD_LIBRARY_PATH", val) -} - -// Walk the sysfs nodes for the available GPUs and gather information from them -// skipping over any devices in the skip map -func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) { - resp.memInfo.DeviceCount = 0 - resp.memInfo.TotalMemory = 0 - resp.memInfo.FreeMemory = 0 - slog.Debug("discovering VRAM for amdgpu devices") - if len(ids) == 0 { - entries, err := os.ReadDir(AMDNodesSysfsDir) - if err != nil { - slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err)) - return - } - for _, node := range entries { - if !node.IsDir() { - continue - } - id, err := strconv.Atoi(node.Name()) - if err != nil { - slog.Warn("malformed amdgpu sysfs node id " + node.Name()) - continue - } - ids = append(ids, id) - } - } - slog.Debug(fmt.Sprintf("amdgpu devices %v", ids)) - - for _, id := range ids { - if _, skipped := skip[id]; skipped { + slog.Debug("failed to open sysfs node", "file", match, "error", err) continue } + defer fp.Close() + nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match))) + if err != nil { + slog.Debug("failed to parse node ID", "error", err) + continue + } + + scanner := bufio.NewScanner(fp) + isCPU := false + var major, minor, patch uint64 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + // Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs + if strings.HasPrefix(line, "gfx_target_version") { + ver := strings.Fields(line) + + // Detect CPUs + if len(ver) == 2 && ver[1] == "0" { + slog.Debug("detected CPU " + match) + isCPU = true + break + } + + if len(ver) != 2 || len(ver[1]) < 5 { + slog.Warn("malformed "+match, "gfx_target_version", line) + // If this winds up being a CPU, our offsets may be wrong + continue + } + l := len(ver[1]) + var err1, err2, err3 error + patch, err1 = strconv.ParseUint(ver[1][l-2:l], 10, 32) + minor, err2 = strconv.ParseUint(ver[1][l-4:l-2], 10, 32) + major, err3 = strconv.ParseUint(ver[1][:l-4], 10, 32) + if err1 != nil || err2 != nil || err3 != nil { + slog.Debug("malformed int " + line) + continue + } + } + + // TODO - any other properties we want to extract and record? + // vendor_id + device_id -> pci lookup for "Name" + // Other metrics that may help us understand relative performance between multiple GPUs + } + + if isCPU { + cpuCount++ + continue + } + + // CPUs are always first in the list + gpuID := nodeID - cpuCount + + // Shouldn't happen, but just in case... 
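+ // (a negative ID would mean a CPU node was enumerated after a GPU node, breaking the "CPUs first" assumption above)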
+ if gpuID < 0 { + slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue") + return []GpuInfo{} + } + + if int(major) < RocmComputeMin { + slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%d%x", major, minor, patch), "gpu", gpuID) + continue + } + + // Look up the memory for the current node totalMemory := uint64(0) usedMemory := uint64(0) - // Adjust for sysfs vs HIP ids - propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id+1), GPUTotalMemoryFileGlob) + propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob) propFiles, err := filepath.Glob(propGlob) if err != nil { - slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err)) + slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err) } // 1 or more memory banks - sum the values of all of them for _, propFile := range propFiles { fp, err := os.Open(propFile) if err != nil { - slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err)) + slog.Warn("failed to open sysfs node", "file", propFile, "error", err) continue } defer fp.Close() @@ -226,49 +179,113 @@ func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) { } } if totalMemory == 0 { - slog.Warn(fmt.Sprintf("amdgpu [%d] reports zero total memory, skipping", id)) - skip[id] = struct{}{} + slog.Warn("amdgpu reports zero total memory", "gpu", gpuID) continue } - if totalMemory < IGPUMemLimit { - slog.Info(fmt.Sprintf("amdgpu [%d] appears to be an iGPU with %dM reported total memory, skipping", id, totalMemory/1024/1024)) - skip[id] = struct{}{} - continue - } - usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob) + usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob) usedFiles, err := filepath.Glob(usedGlob) if err != nil { - slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err)) + slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err) continue } for _, usedFile := range usedFiles { fp, err := os.Open(usedFile) if err != nil { - slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err)) + slog.Warn("failed to open sysfs node", "file", usedFile, "error", err) continue } defer fp.Close() data, err := io.ReadAll(fp) if err != nil { - slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err)) + slog.Warn("failed to read sysfs node", "file", usedFile, "error", err) continue } used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) if err != nil { - slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err)) + slog.Warn("malformed used memory", "data", string(data), "error", err) continue } usedMemory += used } - slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %dM", id, totalMemory/1024/1024)) - slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %dM", id, (totalMemory-usedMemory)/1024/1024)) - resp.memInfo.DeviceCount++ - resp.memInfo.TotalMemory += totalMemory - resp.memInfo.FreeMemory += (totalMemory - usedMemory) + + // iGPU detection, remove this check once we can support an iGPU variant of the rocm library + if totalMemory < IGPUMemLimit { + slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", gpuID, "total", format.HumanBytes2(totalMemory)) + continue + }
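+ // NB: "available" below is derived (total minus the used_memory summed above) rather than read from a dedicated free-memory counter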
+ + slog.Info("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory)) + slog.Info("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory)) + gpuInfo := GpuInfo{ + Library: "rocm", + memInfo: memInfo{ + TotalMemory: totalMemory, + FreeMemory: (totalMemory - usedMemory), + }, + ID: fmt.Sprintf("%d", gpuID), + // Name: not exposed in sysfs directly, would require pci device id lookup + Major: int(major), + Minor: int(minor), + Patch: int(patch), + MinimumMemory: rocmMinimumMemory, + } + + // If the user wants to filter to a subset of devices, filter out if we aren't a match + if len(visibleDevices) > 0 { + include := false + for _, visible := range visibleDevices { + if visible == gpuInfo.ID { + include = true + break + } + } + if !include { + slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices) + continue + } + } + + // Final validation is gfx compatibility - load the library if we haven't already loaded it + // even if the user overrides, we still need to validate the library + if libDir == "" { + libDir, err = AMDValidateLibDir() + if err != nil { + slog.Warn("unable to verify rocm library, will use cpu", "error", err) + return []GpuInfo{} + } + } + gpuInfo.DependencyPath = libDir + + if gfxOverride == "" { + // Only load supported list once + if len(supported) == 0 { + supported, err = GetSupportedGFX(libDir) + if err != nil { + slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err) + return []GpuInfo{} + } + slog.Debug("rocm supported GPUs", "types", supported) + } + gfx := fmt.Sprintf("gfx%d%d%x", gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch) + if !slices.Contains[[]string, string](supported, gfx) { + slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported) + // TODO - consider discrete markdown just for ROCM troubleshooting?
+ slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") + continue + } else { + slog.Info("amdgpu is supported", "gpu", gpuInfo.ID, "gpu_type", gfx) + } + } else { + slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) + } + + // The GPU has passed all the verification steps and is supported + resp = append(resp, gpuInfo) } - if resp.memInfo.DeviceCount > 0 { - resp.Library = "rocm" + if len(resp) == 0 { + slog.Info("no compatible amdgpu devices detected") } + return resp } // Quick check for AMD driver so we can skip amdgpu discovery if not present @@ -280,87 +297,24 @@ func AMDDetected() bool { slog.Debug("amdgpu driver not detected " + sysfsDir) return false } else if err != nil { - slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err)) + slog.Debug("error looking up amd driver", "path", sysfsDir, "error", err) return false } return true } -func setupLink(source, target string) error { - if err := os.RemoveAll(target); err != nil { - return fmt.Errorf("failed to remove old rocm directory %s %w", target, err) - } - if err := os.Symlink(source, target); err != nil { - return fmt.Errorf("failed to create link %s => %s %w", source, target, err) - } - slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target)) - return nil -} - -// Ensure the AMD rocm lib dir is wired up // Prefer to use host installed ROCm, as long as it meets our minimum requirements // failing that, tell the user how to download it on their own func AMDValidateLibDir() (string, error) { - // We rely on the rpath compiled into our library to find rocm - // so we establish a symlink to wherever we find it on the system - // to /rocm - payloadsDir, err := PayloadsDir() - if err != nil { - return "", err - } - - // If we already have a rocm dependency wired, nothing more to do - rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm")) - if rocmLibUsable(rocmTargetDir) { - return rocmTargetDir, nil - } - - // next to the running binary - exe, err := os.Executable() + libDir, err := commonAMDValidateLibDir() if err == nil { - peerDir := filepath.Dir(exe) - if rocmLibUsable(peerDir) { - slog.Debug("detected ROCM next to ollama executable " + peerDir) - return rocmTargetDir, setupLink(peerDir, rocmTargetDir) - } - peerDir = filepath.Join(filepath.Dir(exe), "rocm") - if rocmLibUsable(peerDir) { - slog.Debug("detected ROCM next to ollama executable " + peerDir) - return rocmTargetDir, setupLink(peerDir, rocmTargetDir) - } + return libDir, nil } // Well known ollama installer path installedRocmDir := "/usr/share/ollama/lib/rocm" if rocmLibUsable(installedRocmDir) { - return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir) - } - - // Prefer explicit HIP env var - hipPath := os.Getenv("HIP_PATH") - if hipPath != "" { - hipLibDir := filepath.Join(hipPath, "lib") - if rocmLibUsable(hipLibDir) { - slog.Debug("detected ROCM via HIP_PATH=" + hipPath) - return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir) - } - } - - // Scan the library path for potential matches - ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":") - for _, ldPath := range ldPaths { - d, err := filepath.Abs(ldPath) - if err != nil { - continue - } - if rocmLibUsable(d) { - return rocmTargetDir, setupLink(d, rocmTargetDir) - } - } - - // Well known location(s) - if rocmLibUsable("/opt/rocm/lib") { - return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir) + return installedRocmDir, nil } // If 
we still haven't found a usable rocm, the user will have to install it on their own @@ -384,68 +338,3 @@ func AMDDriverVersion() (string, error) { } return strings.TrimSpace(string(verString)), nil } - -func AMDGFXVersions() map[int]Version { - // The amdgpu driver always exposes the host CPU as node 0, but we have to skip that and subtract one - // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU) - res := map[int]Version{} - matches, _ := filepath.Glob(GPUPropertiesFileGlob) - for _, match := range matches { - fp, err := os.Open(match) - if err != nil { - slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err)) - continue - } - defer fp.Close() - i, err := strconv.Atoi(filepath.Base(filepath.Dir(match))) - if err != nil { - slog.Debug(fmt.Sprintf("failed to parse node ID %s", err)) - continue - } - - if i == 0 { - // Skipping the CPU - continue - } - // Align with HIP IDs (zero is first GPU, not CPU) - i -= 1 - - scanner := bufio.NewScanner(fp) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if strings.HasPrefix(line, "gfx_target_version") { - ver := strings.Fields(line) - if len(ver) != 2 || len(ver[1]) < 5 { - if ver[1] != "0" { - slog.Debug("malformed " + line) - } - res[i] = Version{ - Major: 0, - Minor: 0, - Patch: 0, - } - continue - } - l := len(ver[1]) - patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32) - minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32) - major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32) - if err1 != nil || err2 != nil || err3 != nil { - slog.Debug("malformed int " + line) - continue - } - - res[i] = Version{ - Major: uint(major), - Minor: uint(minor), - Patch: uint(patch), - } - } - } - } - return res -} - -func (v Version) ToGFXString() string { - return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch) -} diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go index be1be567..22c9f427 100644 --- a/gpu/amd_windows.go +++ b/gpu/amd_windows.go @@ -7,11 +7,13 @@ import ( "os" "path/filepath" "slices" + "strconv" "strings" + + "github.com/ollama/ollama/format" ) const ( - RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob? // TODO We're looking for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true iGPUName = "AMD Radeon(TM) Graphics" @@ -19,39 +21,36 @@ const ( var ( // Used to validate if the given ROCm lib is usable - ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here... + ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here... + RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
) -func AMDGetGPUInfo(resp *GpuInfo) { +func AMDGetGPUInfo() []GpuInfo { + resp := []GpuInfo{} hl, err := NewHipLib() if err != nil { slog.Debug(err.Error()) - return + return nil } defer hl.Release() - skip := map[int]interface{}{} - ids := []int{} - resp.memInfo.DeviceCount = 0 - resp.memInfo.TotalMemory = 0 - resp.memInfo.FreeMemory = 0 ver, err := hl.AMDDriverVersion() if err == nil { slog.Info("AMD Driver: " + ver) } else { // For now this is benign, but we may eventually need to fail compatibility checks - slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err)) + slog.Debug("error looking up amd driver version", "error", err) } - // Note: the HIP library automatically handles HIP_VISIBLE_DEVICES + // Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified count := hl.HipGetDeviceCount() if count == 0 { - return + return nil } libDir, err := AMDValidateLibDir() if err != nil { - slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err)) - return + slog.Warn("unable to verify rocm library, will use cpu", "error", err) + return nil } var supported []string @@ -59,95 +58,120 @@ func AMDGetGPUInfo(resp *GpuInfo) { if gfxOverride == "" { supported, err = GetSupportedGFX(libDir) if err != nil { - slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) - return + slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err) + return nil } } else { slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) } - slog.Info(fmt.Sprintf("detected %d hip devices", count)) + slog.Info("detected hip devices", "count", count) + // TODO how to determine the underlying device ID when visible devices is causing this to subset? for i := 0; i < count; i++ { - ids = append(ids, i) err = hl.HipSetDevice(i) if err != nil { - slog.Warn(fmt.Sprintf("[%d] %s", i, err)) - skip[i] = struct{}{} + slog.Warn("set device", "id", i, "error", err) continue } props, err := hl.HipGetDeviceProperties(i) if err != nil { - slog.Warn(fmt.Sprintf("[%d] %s", i, err)) - skip[i] = struct{}{} + slog.Warn("get properties", "id", i, "error", err) continue } n := bytes.IndexByte(props.Name[:], 0) name := string(props.Name[:n]) - slog.Info(fmt.Sprintf("[%d] Name: %s", i, name)) + // TODO is UUID actually populated on windows? + // Can luid be used on windows for setting visible devices (and is it actually set?) n = bytes.IndexByte(props.GcnArchName[:], 0) gfx := string(props.GcnArchName[:n]) - slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx)) + slog.Info("hip device", "id", i, "name", name, "gfx", gfx) + var major, minor, patch string + switch len(gfx) { + case 6: + major, minor, patch = gfx[3:4], gfx[4:5], gfx[5:] + case 7: + major, minor, patch = gfx[3:5], gfx[5:6], gfx[6:] + } //slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 // TODO Why isn't props.iGPU accurate!? if strings.EqualFold(name, iGPUName) { - slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i)) - skip[i] = struct{}{} + slog.Info("iGPU detected skipping", "id", i) continue } if gfxOverride == "" { if !slices.Contains[[]string, string](supported, gfx) { - slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported)) + slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported) // TODO - consider discrete markdown just for ROCM troubleshooting? 
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") - skip[i] = struct{}{} continue } else { - slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx)) + slog.Info("amdgpu is supported", "gpu", i, "gpu_type", gfx) } } - totalMemory, freeMemory, err := hl.HipMemGetInfo() + freeMemory, totalMemory, err := hl.HipMemGetInfo() if err != nil { - slog.Warn(fmt.Sprintf("[%d] %s", i, err)) + slog.Warn("get mem info", "id", i, "error", err) continue } - // TODO according to docs, freeMem may lie on windows! - slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory)) - slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory)) - resp.memInfo.DeviceCount++ - resp.memInfo.TotalMemory += totalMemory - resp.memInfo.FreeMemory += freeMemory + // iGPU detection, remove this check once we can support an iGPU variant of the rocm library + if totalMemory < IGPUMemLimit { + slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory)) + continue + } + + // TODO revisit this once ROCm v6 is available on windows. + // v5.7 only reports VRAM used by this process, so it's completely wrong and unusable + slog.Info("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) + slog.Info("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) + gpuInfo := GpuInfo{ + Library: "rocm", + memInfo: memInfo{ + TotalMemory: totalMemory, + FreeMemory: freeMemory, + }, + ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices + DependencyPath: libDir, + MinimumMemory: rocmMinimumMemory, + } + if major != "" { + gpuInfo.Major, err = strconv.Atoi(major) + if err != nil { + slog.Info("failed to parse version", "version", gfx, "error", err) + } + } + if minor != "" { + gpuInfo.Minor, err = strconv.Atoi(minor) + if err != nil { + slog.Info("failed to parse version", "version", gfx, "error", err) + } + } + if patch != "" { + // Patch rev is hex; e.g. 
gfx90a + p, err := strconv.ParseInt(patch, 16, 0) + if err != nil { + slog.Info("failed to parse version", "version", gfx, "error", err) + } else { + gpuInfo.Patch = int(p) + } + } + if gpuInfo.Major < RocmComputeMin { + slog.Warn(fmt.Sprintf("amdgpu [%s] too old gfx%d%d%x", gpuInfo.ID, gpuInfo.Major, gpuInfo.Minor, gpuInfo.Patch)) + continue + } + + resp = append(resp, gpuInfo) } - if resp.memInfo.DeviceCount > 0 { - resp.Library = "rocm" - } - // Abort if all GPUs are skipped - if len(skip) >= count { - slog.Info("all detected amdgpus are skipped, falling back to CPU") - return - } - if len(skip) > 0 { - amdSetVisibleDevices(ids, skip) - } - UpdatePath(libDir) + + return resp } func AMDValidateLibDir() (string, error) { - // On windows non-admins typically can't create links - // so instead of trying to rely on rpath and a link in - // $LibDir/rocm, we instead rely on setting PATH to point - // to the location of the ROCm library - - // Installer payload location if we're running the installed binary - exe, err := os.Executable() + libDir, err := commonAMDValidateLibDir() if err == nil { - rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm") - if rocmLibUsable(rocmTargetDir) { - slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) - return rocmTargetDir, nil - } + return libDir, nil } // Installer payload (if we're running from some other location) @@ -159,21 +183,6 @@ func AMDValidateLibDir() (string, error) { return rocmTargetDir, nil } - // Prefer explicit HIP env var - hipPath := os.Getenv("HIP_PATH") - if hipPath != "" { - hipLibDir := filepath.Join(hipPath, "bin") - if rocmLibUsable(hipLibDir) { - slog.Debug("detected ROCM via HIP_PATH=" + hipPath) - return hipLibDir, nil - } - } - - // Well known location(s) - if rocmLibUsable(RocmStandardLocation) { - return RocmStandardLocation, nil - } - // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this slog.Warn("amdgpu detected, but no compatible rocm library found. 
Please install ROCm") return "", fmt.Errorf("no suitable rocm found, falling back to CPU") diff --git a/gpu/assets.go b/gpu/assets.go index 085c05bc..911a6977 100644 --- a/gpu/assets.go +++ b/gpu/assets.go @@ -12,6 +12,8 @@ import ( "sync" "syscall" "time" + + "github.com/ollama/ollama/server/envconfig" ) var ( @@ -24,8 +26,16 @@ func PayloadsDir() (string, error) { defer lock.Unlock() var err error if payloadsDir == "" { + runnersDir := envconfig.RunnersDir + + if runnersDir != "" { + payloadsDir = runnersDir + return payloadsDir, nil + } + + // The remainder only applies on non-windows where we still carry payloads in the main executable cleanupTmpDirs() - tmpDir := os.Getenv("OLLAMA_TMPDIR") + tmpDir := envconfig.TmpDir if tmpDir == "" { tmpDir, err = os.MkdirTemp("", "ollama") if err != nil { @@ -80,7 +90,7 @@ func cleanupTmpDirs() { } err = os.RemoveAll(d) if err != nil { - slog.Debug(fmt.Sprintf("unable to cleanup stale tmpdir %s: %s", d, err)) + slog.Debug("unable to cleanup stale tmpdir", "path", d, "error", err) } } } @@ -88,7 +98,8 @@ func cleanupTmpDirs() { func Cleanup() { lock.Lock() defer lock.Unlock() - if payloadsDir != "" { + runnersDir := envconfig.RunnersDir + if payloadsDir != "" && runnersDir == "" && runtime.GOOS != "windows" { // We want to fully clean up the tmpdir parent of the payloads dir tmpDir := filepath.Clean(filepath.Join(payloadsDir, "..")) slog.Debug("cleaning up", "dir", tmpDir) @@ -120,7 +131,7 @@ func UpdatePath(dir string) { } } newPath := strings.Join(append([]string{dir}, pathComponents...), ";") - slog.Info(fmt.Sprintf("Updating PATH to %s", newPath)) + slog.Info("updating", "PATH", newPath) os.Setenv("PATH", newPath) } // linux and darwin rely on rpath diff --git a/gpu/cuda_common.go b/gpu/cuda_common.go new file mode 100644 index 00000000..03c1a25b --- /dev/null +++ b/gpu/cuda_common.go @@ -0,0 +1,22 @@ +//go:build linux || windows + +package gpu + +import ( + "log/slog" + "strings" +) + +func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "cuda" { + // TODO shouldn't happen if things are wired correctly... 
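+ // (GpuInfoList.GetVisibleDevicesEnv dispatches on the first entry's library, so any stray non-cuda entries are simply skipped here)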
+ slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library) + continue + } + ids = append(ids, info.ID) + } + return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",") + +} diff --git a/gpu/gpu.go b/gpu/gpu.go index 47d70ed0..f8bae9b0 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -16,22 +16,23 @@ import ( "os" "path/filepath" "runtime" - "strconv" "strings" "sync" "unsafe" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/server/envconfig" ) type handles struct { - nvml *C.nvml_handle_t - cudart *C.cudart_handle_t + deviceCount int + cudart *C.cudart_handle_t + nvcuda *C.nvcuda_handle_t } const ( - cudaMinimumMemory = 457 * format.MebiByte - rocmMinimumMemory = 457 * format.MebiByte + cudaMinimumMemory = 256 * format.MebiByte + rocmMinimumMemory = 256 * format.MebiByte ) var gpuMutex sync.Mutex @@ -39,26 +40,10 @@ var gpuMutex sync.Mutex // With our current CUDA compile flags, older than 5.0 will not work properly var CudaComputeMin = [2]C.int{5, 0} -// Possible locations for the nvidia-ml library -var NvmlLinuxGlobs = []string{ - "/usr/local/cuda/lib64/libnvidia-ml.so*", - "/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*", - "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*", - "/usr/lib/wsl/lib/libnvidia-ml.so*", - "/usr/lib/wsl/drivers/*/libnvidia-ml.so*", - "/opt/cuda/lib64/libnvidia-ml.so*", - "/usr/lib*/libnvidia-ml.so*", - "/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*", - "/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*", - "/usr/local/lib*/libnvidia-ml.so*", +var RocmComputeMin = 9 - // TODO: are these stubs ever valid? - "/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*", -} - -var NvmlWindowsGlobs = []string{ - "c:\\Windows\\System32\\nvml.dll", -} +// TODO find a better way to detect iGPU instead of minimum memory +const IGPUMemLimit = 1 * format.GibiByte // 512M is what they typically report, so anything less than 1G must be iGPU var CudartLinuxGlobs = []string{ "/usr/local/cuda/lib64/libcudart.so*", @@ -79,6 +64,22 @@ var CudartWindowsGlobs = []string{ "c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll", } +var NvcudaLinuxGlobs = []string{ + "/usr/local/cuda*/targets/*/lib/libcuda.so*", + "/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*", + "/usr/lib/*-linux-gnu/libcuda.so*", + "/usr/lib/wsl/lib/libcuda.so*", + "/usr/lib/wsl/drivers/*/libcuda.so*", + "/opt/cuda/lib*/libcuda.so*", + "/usr/local/cuda/lib*/libcuda.so*", + "/usr/lib*/libcuda.so*", + "/usr/local/lib*/libcuda.so*", +} + +var NvcudaWindowsGlobs = []string{ + "c:\\windows\\system*\\nvcuda.dll", +} + // Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed. // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
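+ // (a non-empty value such as "5.1.2" marks an L4T/Jetson system; it is empty elsewhere)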
var CudaTegra string = os.Getenv("JETSON_JETPACK") @@ -88,61 +89,62 @@ func initGPUHandles() *handles { // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - gpuHandles := &handles{nil, nil} - var nvmlMgmtName string - var nvmlMgmtPatterns []string + gpuHandles := &handles{} var cudartMgmtName string var cudartMgmtPatterns []string + var nvcudaMgmtName string + var nvcudaMgmtPatterns []string tmpDir, _ := PayloadsDir() switch runtime.GOOS { case "windows": - nvmlMgmtName = "nvml.dll" - nvmlMgmtPatterns = make([]string, len(NvmlWindowsGlobs)) - copy(nvmlMgmtPatterns, NvmlWindowsGlobs) cudartMgmtName = "cudart64_*.dll" localAppData := os.Getenv("LOCALAPPDATA") cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)} cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...) + // Aligned with driver, we can't carry as payloads + nvcudaMgmtName = "nvcuda.dll" + nvcudaMgmtPatterns = NvcudaWindowsGlobs case "linux": - nvmlMgmtName = "libnvidia-ml.so" - nvmlMgmtPatterns = make([]string, len(NvmlLinuxGlobs)) - copy(nvmlMgmtPatterns, NvmlLinuxGlobs) cudartMgmtName = "libcudart.so*" if tmpDir != "" { // TODO - add "payloads" for subprocess cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)} } cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...) + // Aligned with driver, we can't carry as payloads + nvcudaMgmtName = "libcuda.so*" + nvcudaMgmtPatterns = NvcudaLinuxGlobs default: return gpuHandles } - slog.Info("Detecting GPU type") - cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns) - if len(cudartLibPaths) > 0 { - cudart := LoadCUDARTMgmt(cudartLibPaths) - if cudart != nil { - slog.Info("Nvidia GPU detected via cudart") - gpuHandles.cudart = cudart + slog.Info("Detecting GPUs") + nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns) + if len(nvcudaLibPaths) > 0 { + deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths) + if nvcuda != nil { + slog.Info("detected GPUs", "count", deviceCount, "library", libPath) + gpuHandles.nvcuda = nvcuda + gpuHandles.deviceCount = deviceCount return gpuHandles } } - // TODO once we build confidence, remove this and the gpu_info_nvml.[ch] files - nvmlLibPaths := FindGPULibs(nvmlMgmtName, nvmlMgmtPatterns) - if len(nvmlLibPaths) > 0 { - nvml := LoadNVMLMgmt(nvmlLibPaths) - if nvml != nil { - slog.Info("Nvidia GPU detected via nvidia-ml") - gpuHandles.nvml = nvml + cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns) + if len(cudartLibPaths) > 0 { + deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths) + if cudart != nil { + slog.Info("detected GPUs", "library", libPath, "count", deviceCount) + gpuHandles.cudart = cudart + gpuHandles.deviceCount = deviceCount return gpuHandles } } return gpuHandles } -func GetGPUInfo() GpuInfo { +func GetGPUInfo() GpuInfoList { // TODO - consider exploring lspci (and equivalent on windows) to check for // GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries gpuMutex.Lock() @@ -150,12 +152,12 @@ func GetGPUInfo() GpuInfo { gpuHandles := initGPUHandles() defer func() { - if gpuHandles.nvml != nil { - C.nvml_release(*gpuHandles.nvml) - } if gpuHandles.cudart != nil { C.cudart_release(*gpuHandles.cudart) } + if gpuHandles.nvcuda != nil { + C.nvcuda_release(*gpuHandles.nvcuda) + } }() // All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX @@ -164,73 +166,75 @@ func 
GetGPUInfo() GpuInfo { slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.") } - var memInfo C.mem_info_t - resp := GpuInfo{} - if gpuHandles.nvml != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") { - C.nvml_check_vram(*gpuHandles.nvml, &memInfo) - if memInfo.err != nil { - slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU memory: %s", C.GoString(memInfo.err))) - C.free(unsafe.Pointer(memInfo.err)) - } else if memInfo.count > 0 { - // Verify minimum compute capability - var cc C.nvml_compute_capability_t - C.nvml_compute_capability(*gpuHandles.nvml, &cc) - if cc.err != nil { - slog.Info(fmt.Sprintf("[nvidia-ml] error looking up NVML GPU compute capability: %s", C.GoString(cc.err))) - C.free(unsafe.Pointer(cc.err)) - } else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) { - slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)) - resp.Library = "cuda" - resp.MinimumMemory = cudaMinimumMemory - } else { - slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)) - } - } - } else if gpuHandles.cudart != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") { - C.cudart_check_vram(*gpuHandles.cudart, &memInfo) - if memInfo.err != nil { - slog.Info(fmt.Sprintf("[cudart] error looking up CUDART GPU memory: %s", C.GoString(memInfo.err))) - C.free(unsafe.Pointer(memInfo.err)) - } else if memInfo.count > 0 { - // Verify minimum compute capability - var cc C.cudart_compute_capability_t - C.cudart_compute_capability(*gpuHandles.cudart, &cc) - if cc.err != nil { - slog.Info(fmt.Sprintf("[cudart] error looking up CUDA compute capability: %s", C.GoString(cc.err))) - C.free(unsafe.Pointer(cc.err)) - } else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) { - slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)) - resp.Library = "cuda" - resp.MinimumMemory = cudaMinimumMemory - } else { - slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)) - } - } - } else { - AMDGetGPUInfo(&resp) - if resp.Library != "" { - resp.MinimumMemory = rocmMinimumMemory - return resp - } - } - if resp.Library == "" { - C.cpu_check_ram(&memInfo) - resp.Library = "cpu" - resp.Variant = cpuVariant - } - if memInfo.err != nil { - slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err))) - C.free(unsafe.Pointer(memInfo.err)) - return resp + // On windows we bundle the nvidia library one level above the runner dir + depPath := "" + if runtime.GOOS == "windows" && envconfig.RunnersDir != "" { + depPath = filepath.Dir(envconfig.RunnersDir) + } + + var memInfo C.mem_info_t + resp := []GpuInfo{} + + // NVIDIA first + for i := 0; i < gpuHandles.deviceCount; i++ { + // TODO once we support CPU compilation variants of GPU libraries refine this... 
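+ // (an empty cpuVariant on amd64 means no AVX was detected, and the GPU runners assume at least AVX per the check above)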
+ if cpuVariant == "" && runtime.GOARCH == "amd64" { + continue + } + gpuInfo := GpuInfo{ + Library: "cuda", + } + if gpuHandles.cudart != nil { + C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo) + } else { + C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo) + } + if memInfo.err != nil { + slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) + C.free(unsafe.Pointer(memInfo.err)) + continue + } + if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) { + slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) + continue + } + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + gpuInfo.Major = int(memInfo.major) + gpuInfo.Minor = int(memInfo.minor) + gpuInfo.MinimumMemory = cudaMinimumMemory + gpuInfo.DependencyPath = depPath + + // TODO potentially sort on our own algorithm instead of what the underlying GPU library does... + resp = append(resp, gpuInfo) + } + + // Then AMD + resp = append(resp, AMDGetGPUInfo()...) + + if len(resp) == 0 { + C.cpu_check_ram(&memInfo) + if memInfo.err != nil { + slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err)) + C.free(unsafe.Pointer(memInfo.err)) + return resp + } + gpuInfo := GpuInfo{ + Library: "cpu", + Variant: cpuVariant, + } + gpuInfo.TotalMemory = uint64(memInfo.total) + gpuInfo.FreeMemory = uint64(memInfo.free) + gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) + + resp = append(resp, gpuInfo) } - resp.DeviceCount = uint32(memInfo.count) - resp.FreeMemory = uint64(memInfo.free) - resp.TotalMemory = uint64(memInfo.total) return resp } -func getCPUMem() (memInfo, error) { +func GetCPUMem() (memInfo, error) { var ret memInfo var info C.mem_info_t C.cpu_check_ram(&info) @@ -243,29 +247,12 @@ func getCPUMem() (memInfo, error) { return ret, nil } -func CheckVRAM() (uint64, error) { - userLimit := os.Getenv("OLLAMA_MAX_VRAM") - if userLimit != "" { - avail, err := strconv.ParseInt(userLimit, 10, 64) - if err != nil { - return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err) - } - slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail)) - return uint64(avail), nil - } - gpuInfo := GetGPUInfo() - if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") { - return gpuInfo.FreeMemory, nil - } - - return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation -} - -func FindGPULibs(baseLibName string, patterns []string) []string { +func FindGPULibs(baseLibName string, defaultPatterns []string) []string { // Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them var ldPaths []string + var patterns []string gpuLibPaths := []string{} - slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName)) + slog.Debug("Searching for GPU library", "name", baseLibName) switch runtime.GOOS { case "windows": @@ -283,8 +270,14 @@ func FindGPULibs(baseLibName string, patterns []string) []string { } patterns = append(patterns, filepath.Join(d, baseLibName+"*")) } - slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns)) + patterns = append(patterns, defaultPatterns...) 
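+ // patterns derived from the library path environment come first, so a library found there wins over the default globs appended above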
+ slog.Debug("gpu library search", "globs", patterns) for _, pattern := range patterns { + + // Nvidia PhysX known to return bogus results + if strings.Contains(pattern, "PhysX") { + slog.Debug("skipping PhysX cuda library path", "path", pattern) + continue + } // Ignore glob discovery errors matches, _ := filepath.Glob(pattern) for _, match := range matches { @@ -311,28 +304,11 @@ } } } - slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths)) + slog.Debug("discovered GPU libraries", "paths", gpuLibPaths) return gpuLibPaths } -func LoadNVMLMgmt(nvmlLibPaths []string) *C.nvml_handle_t { - var resp C.nvml_init_resp_t - resp.ch.verbose = getVerboseState() - for _, libPath := range nvmlLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.nvml_init(lib, &resp) - if resp.err != nil { - slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err))) - C.free(unsafe.Pointer(resp.err)) - } else { - return &resp.ch - } - } - return nil -} - -func LoadCUDARTMgmt(cudartLibPaths []string) *C.cudart_handle_t { +func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) { var resp C.cudart_init_resp_t resp.ch.verbose = getVerboseState() for _, libPath := range cudartLibPaths { @@ -340,18 +316,54 @@ defer C.free(unsafe.Pointer(lib)) C.cudart_init(lib, &resp) if resp.err != nil { - slog.Info(fmt.Sprintf("Unable to load cudart CUDA management library %s: %s", libPath, C.GoString(resp.err))) + slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err)) C.free(unsafe.Pointer(resp.err)) } else { - return &resp.ch + return int(resp.num_devices), &resp.ch, libPath } } - return nil + return 0, nil, "" +} + +func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) { + var resp C.nvcuda_init_resp_t + resp.ch.verbose = getVerboseState() + for _, libPath := range nvcudaLibPaths { + lib := C.CString(libPath) + defer C.free(unsafe.Pointer(lib)) + C.nvcuda_init(lib, &resp) + if resp.err != nil { + slog.Debug("Unable to load nvcuda", "library", libPath, "error", C.GoString(resp.err)) + C.free(unsafe.Pointer(resp.err)) + } else { + return int(resp.num_devices), &resp.ch, libPath + } + } + return 0, nil, "" } func getVerboseState() C.uint16_t { - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { return C.uint16_t(1) } return C.uint16_t(0) } + +// Given the list of GPUs this instantiation is targeted for, +// figure out the visible devices environment variable +// +// If different libraries are detected, the first one is what we use +func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { + if len(l) == 0 { + return "", "" + } + switch l[0].Library { + case "cuda": + return cudaGetVisibleDevicesEnv(l) + case "rocm": + return rocmGetVisibleDevicesEnv(l) + default: + slog.Debug("no filter required for library " + l[0].Library) + return "", "" + } +}
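+ +// Example (hypothetical): a caller launching a runner subprocess could apply the result as +// if envVar, val := gpus.GetVisibleDevicesEnv(); envVar != "" { +// cmd.Env = append(cmd.Env, envVar+"="+val) +// }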
diff --git a/gpu/gpu_darwin.go b/gpu/gpu_darwin.go index bf764ce6..0ba02e1b 100644 --- a/gpu/gpu_darwin.go +++ b/gpu/gpu_darwin.go @@ -9,52 +9,47 @@ package gpu */ import "C" import ( - "fmt" - "log/slog" - "os" "runtime" - "strconv" + + "github.com/ollama/ollama/format" ) -// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs -func CheckVRAM() (uint64, error) { - userLimit := os.Getenv("OLLAMA_MAX_VRAM") - if userLimit != "" { - avail, err := strconv.ParseInt(userLimit, 10, 64) - if err != nil { - return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err) - } - slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail)) - return uint64(avail), nil - } +const ( + metalMinimumMemory = 384 * format.MebiByte +) +func GetGPUInfo() GpuInfoList { + mem, _ := GetCPUMem() if runtime.GOARCH == "amd64" { - // gpu not supported, this may not be metal - return 0, nil - } - - return uint64(C.getRecommendedMaxVRAM()), nil -} - -func GetGPUInfo() GpuInfo { - mem, _ := getCPUMem() - if runtime.GOARCH == "amd64" { - return GpuInfo{ - Library: "cpu", - Variant: GetCPUVariant(), - memInfo: mem, + return []GpuInfo{ + { + Library: "cpu", + Variant: GetCPUVariant(), + memInfo: mem, + }, } } - return GpuInfo{ + info := GpuInfo{ Library: "metal", - memInfo: mem, + ID: "0", } + info.TotalMemory = uint64(C.getRecommendedMaxVRAM()) + + // TODO is there a way to gather actual allocated video memory? (currentAllocatedSize doesn't work) + info.FreeMemory = info.TotalMemory + + info.MinimumMemory = metalMinimumMemory + return []GpuInfo{info} } -func getCPUMem() (memInfo, error) { +func GetCPUMem() (memInfo, error) { return memInfo{ TotalMemory: uint64(C.getPhysicalMemory()), FreeMemory: 0, - DeviceCount: 1, }, nil } + +func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { + // No-op on darwin + return "", "" +} diff --git a/gpu/gpu_info.h b/gpu/gpu_info.h index 4c449a60..577bd3f0 100644 --- a/gpu/gpu_info.h +++ b/gpu/gpu_info.h @@ -38,12 +38,17 @@ extern "C" { #endif +#define GPU_ID_LEN 64 + typedef struct mem_info { + char *err; // If non-nil, caller responsible for freeing + char gpu_id[GPU_ID_LEN]; uint64_t total; uint64_t free; - unsigned int count; - int igpu_index; // If >= 0, we detected an integrated GPU to ignore - char *err; // If non-nill, caller responsible for freeing + + // Compute Capability + int major; + int minor; } mem_info_t; void cpu_check_ram(mem_info_t *resp); @@ -52,8 +57,8 @@ void cpu_check_ram(mem_info_t *resp); } #endif -#include "gpu_info_nvml.h" #include "gpu_info_cudart.h" +#include "gpu_info_nvcuda.h" #endif // __GPU_INFO_H__ #endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_info_cpu.c b/gpu/gpu_info_cpu.c index 0c4d62c5..81ba3de4 100644 --- a/gpu/gpu_info_cpu.c +++ b/gpu/gpu_info_cpu.c @@ -8,9 +8,11 @@ void cpu_check_ram(mem_info_t *resp) { MEMORYSTATUSEX info; info.dwLength = sizeof(info); if (GlobalMemoryStatusEx(&info) != 0) { - resp->count = 1; resp->total = info.ullTotalPhys; resp->free = info.ullAvailPhys; + resp->major = 0; + resp->minor = 0; + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0"); } else { resp->err = LOAD_ERR(); } @@ -27,9 +29,11 @@ void cpu_check_ram(mem_info_t *resp) { if (sysinfo(&info) != 0) { resp->err = strdup(strerror(errno)); } else { - resp->count = 1; resp->total = info.totalram * info.mem_unit; resp->free = info.freeram * info.mem_unit; + resp->major = 0; + resp->minor = 0; + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0"); } return; } diff --git a/gpu/gpu_info_cudart.c b/gpu/gpu_info_cudart.c index 27cd2342..8e9204ea 100644 --- a/gpu/gpu_info_cudart.c +++ b/gpu/gpu_info_cudart.c @@ -6,6 +6,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { cudartReturn_t ret; resp->err = NULL; + resp->num_devices = 0; const int buflen = 256; char buf[buflen + 1]; int i; @@ -21,6 +22,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { {"cudaGetDeviceCount", (void *)&resp->ch.cudaGetDeviceCount}, {"cudaDeviceGetAttribute",
(void *)&resp->ch.cudaDeviceGetAttribute}, {"cudaDriverGetVersion", (void *)&resp->ch.cudaDriverGetVersion}, + {"cudaGetDeviceProperties", (void *)&resp->ch.cudaGetDeviceProperties}, {NULL, NULL}, }; @@ -36,13 +38,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { return; } - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "wiring cudart library functions in %s\n", cudart_lib_path); - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); if (!l[i].p) { char *msg = LOAD_ERR(); @@ -63,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { UNLOAD_LIBRARY(resp->ch.handle); resp->ch.handle = NULL; - if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) { - resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama"); + if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) { + resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama"); return; } snprintf(buf, buflen, "cudart init failure: %d", ret); @@ -85,110 +81,95 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { driverVersion.minor = (version - (driverVersion.major * 1000)) / 10; LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor); } + + ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices); + if (ret != CUDART_SUCCESS) { + LOG(resp->ch.verbose, "cudaGetDeviceCount err: %d\n", ret); + UNLOAD_LIBRARY(resp->ch.handle); + resp->ch.handle = NULL; + snprintf(buf, buflen, "unable to get device count: %d", ret); + resp->err = strdup(buf); + return; + } } -void cudart_check_vram(cudart_handle_t h, mem_info_t *resp) { +void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) { resp->err = NULL; cudartMemory_t memInfo = {0,0,0}; cudartReturn_t ret; const int buflen = 256; char buf[buflen + 1]; - int i; if (h.handle == NULL) { resp->err = strdup("cudart handle isn't initialized"); return; } - // cudaGetDeviceCount takes int type, resp-> count is uint - int deviceCount; - ret = (*h.cudaGetDeviceCount)(&deviceCount); + ret = (*h.cudaSetDevice)(i); if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); + snprintf(buf, buflen, "cudart device failed to initialize"); resp->err = strdup(buf); return; + } + + cudaDeviceProp_t props; + ret = (*h.cudaGetDeviceProperties)(&props, i); + if (ret != CUDART_SUCCESS) { + LOG(h.verbose, "[%d] device properties lookup failure: %d\n", i, ret); + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + resp->major = 0; + resp->minor = 0; } else { - resp->count = (unsigned int)deviceCount; - } - - resp->total = 0; - resp->free = 0; - for (i = 0; i < resp-> count; i++) { - ret = (*h.cudaSetDevice)(i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device failed to initialize"); - resp->err = strdup(buf); - return; + int allNull = 1; + for (int j = 0; j < 16; j++) { + if (props.uuid.bytes[j] != 0) { + allNull = 0; + break; + } } - ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); - resp->err = strdup(buf); - return; + if (allNull != 0) { + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + } else { + // GPU-d110a105-ac29-1d54-7b49-9c90440f215b + snprintf(&resp->gpu_id[0], GPU_ID_LEN, +
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + props.uuid.bytes[0], + props.uuid.bytes[1], + props.uuid.bytes[2], + props.uuid.bytes[3], + props.uuid.bytes[4], + props.uuid.bytes[5], + props.uuid.bytes[6], + props.uuid.bytes[7], + props.uuid.bytes[8], + props.uuid.bytes[9], + props.uuid.bytes[10], + props.uuid.bytes[11], + props.uuid.bytes[12], + props.uuid.bytes[13], + props.uuid.bytes[14], + props.uuid.bytes[15] + ); } + resp->major = props.major; + resp->minor = props.minor; - LOG(h.verbose, "[%d] CUDA totalMem %lu\n", i, memInfo.total); - LOG(h.verbose, "[%d] CUDA freeMem %lu\n", i, memInfo.free); - - resp->total += memInfo.total; - resp->free += memInfo.free; + // TODO add other useful properties from props } -} - -void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *resp) { - resp->err = NULL; - resp->major = 0; - resp->minor = 0; - int major = 0; - int minor = 0; - cudartReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - if (h.handle == NULL) { - resp->err = strdup("cudart handle not initialized"); - return; - } - - int devices; - ret = (*h.cudaGetDeviceCount)(&devices); + ret = (*h.cudaMemGetInfo)(&memInfo.free, &memInfo.total); if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "unable to get cudart device count: %d", ret); + snprintf(buf, buflen, "cudart device memory info lookup failure %d", ret); resp->err = strdup(buf); return; } - for (i = 0; i < devices; i++) { - ret = (*h.cudaSetDevice)(i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "cudart device failed to initialize"); - resp->err = strdup(buf); - return; - } + resp->total = memInfo.total; + resp->free = memInfo.free; - ret = (*h.cudaDeviceGetAttribute)(&major, cudartDevAttrComputeCapabilityMajor, i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - ret = (*h.cudaDeviceGetAttribute)(&minor, cudartDevAttrComputeCapabilityMinor, i); - if (ret != CUDART_SUCCESS) { - snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - - // Report the lowest major.minor we detect as that limits our compatibility - if (resp->major == 0 || resp->major > major ) { - resp->major = major; - resp->minor = minor; - } else if ( resp->major == major && resp->minor > minor ) { - resp->minor = minor; - } - } + LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total); + LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free); + LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); } void cudart_release(cudart_handle_t h) { diff --git a/gpu/gpu_info_cudart.h b/gpu/gpu_info_cudart.h index eb9336ec..e8a89856 100644 --- a/gpu/gpu_info_cudart.h +++ b/gpu/gpu_info_cudart.h @@ -6,14 +6,20 @@ // Just enough typedef's to dlopen/dlsym for memory information typedef enum cudartReturn_enum { CUDART_SUCCESS = 0, - CUDART_UNSUPPORTED = 1, - CUDA_ERROR_INSUFFICIENT_DRIVER = 35, + CUDART_ERROR_INVALID_VALUE = 1, + CUDART_ERROR_MEMORY_ALLOCATION = 2, + CUDART_ERROR_INSUFFICIENT_DRIVER = 35, // Other values omitted for now... 
} cudartReturn_t; typedef enum cudartDeviceAttr_enum { cudartDevAttrComputeCapabilityMajor = 75, cudartDevAttrComputeCapabilityMinor = 76, + + // TODO - not yet wired up but may be useful for Jetson or other + // integrated GPU scenarios with shared memory + cudaDevAttrIntegrated = 18 + } cudartDeviceAttr_t; typedef void *cudartDevice_t; // Opaque is sufficient @@ -28,6 +34,92 @@ typedef struct cudartDriverVersion { int minor; } cudartDriverVersion_t; +typedef struct cudaUUID { + unsigned char bytes[16]; +} cudaUUID_t; +typedef struct cudaDeviceProp { + char name[256]; /**< ASCII string identifying device */ + cudaUUID_t uuid; /**< 16-byte unique identifier */ + char luid[8]; /**< 8-byte locally unique identifier. Value is undefined on TCC and non-Windows platforms */ + unsigned int luidDeviceNodeMask; /**< LUID device node mask. Value is undefined on TCC and non-Windows platforms */ + size_t totalGlobalMem; /**< Global memory available on device in bytes */ + size_t sharedMemPerBlock; /**< Shared memory available per block in bytes */ + int regsPerBlock; /**< 32-bit registers available per block */ + int warpSize; /**< Warp size in threads */ + size_t memPitch; /**< Maximum pitch in bytes allowed by memory copies */ + int maxThreadsPerBlock; /**< Maximum number of threads per block */ + int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */ + int maxGridSize[3]; /**< Maximum size of each dimension of a grid */ + int clockRate; /**< Clock frequency in kilohertz */ + size_t totalConstMem; /**< Constant memory available on device in bytes */ + int major; /**< Major compute capability */ + int minor; /**< Minor compute capability */ + size_t textureAlignment; /**< Alignment requirement for textures */ + size_t texturePitchAlignment; /**< Pitch alignment requirement for texture references bound to pitched memory */ + int deviceOverlap; /**< Device can concurrently copy memory and execute a kernel. Deprecated. Use instead asyncEngineCount. */ + int multiProcessorCount; /**< Number of multiprocessors on device */ + int kernelExecTimeoutEnabled; /**< Specified whether there is a run time limit on kernels */ + int integrated; /**< Device is integrated as opposed to discrete */ + int canMapHostMemory; /**< Device can map host memory with cudaHostAlloc/cudaHostGetDevicePointer */ + int computeMode; /**< Compute mode (See ::cudaComputeMode) */ + int maxTexture1D; /**< Maximum 1D texture size */ + int maxTexture1DMipmap; /**< Maximum 1D mipmapped texture size */ + int maxTexture1DLinear; /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. 
*/ + int maxTexture2D[2]; /**< Maximum 2D texture dimensions */ + int maxTexture2DMipmap[2]; /**< Maximum 2D mipmapped texture dimensions */ + int maxTexture2DLinear[3]; /**< Maximum dimensions (width, height, pitch) for 2D textures bound to pitched memory */ + int maxTexture2DGather[2]; /**< Maximum 2D texture dimensions if texture gather operations have to be performed */ + int maxTexture3D[3]; /**< Maximum 3D texture dimensions */ + int maxTexture3DAlt[3]; /**< Maximum alternate 3D texture dimensions */ + int maxTextureCubemap; /**< Maximum Cubemap texture dimensions */ + int maxTexture1DLayered[2]; /**< Maximum 1D layered texture dimensions */ + int maxTexture2DLayered[3]; /**< Maximum 2D layered texture dimensions */ + int maxTextureCubemapLayered[2];/**< Maximum Cubemap layered texture dimensions */ + int maxSurface1D; /**< Maximum 1D surface size */ + int maxSurface2D[2]; /**< Maximum 2D surface dimensions */ + int maxSurface3D[3]; /**< Maximum 3D surface dimensions */ + int maxSurface1DLayered[2]; /**< Maximum 1D layered surface dimensions */ + int maxSurface2DLayered[3]; /**< Maximum 2D layered surface dimensions */ + int maxSurfaceCubemap; /**< Maximum Cubemap surface dimensions */ + int maxSurfaceCubemapLayered[2];/**< Maximum Cubemap layered surface dimensions */ + size_t surfaceAlignment; /**< Alignment requirements for surfaces */ + int concurrentKernels; /**< Device can possibly execute multiple kernels concurrently */ + int ECCEnabled; /**< Device has ECC support enabled */ + int pciBusID; /**< PCI bus ID of the device */ + int pciDeviceID; /**< PCI device ID of the device */ + int pciDomainID; /**< PCI domain ID of the device */ + int tccDriver; /**< 1 if device is a Tesla device using TCC driver, 0 otherwise */ + int asyncEngineCount; /**< Number of asynchronous engines */ + int unifiedAddressing; /**< Device shares a unified address space with the host */ + int memoryClockRate; /**< Peak memory clock frequency in kilohertz */ + int memoryBusWidth; /**< Global memory bus width in bits */ + int l2CacheSize; /**< Size of L2 cache in bytes */ + int persistingL2CacheMaxSize; /**< Device's maximum l2 persisting lines capacity setting in bytes */ + int maxThreadsPerMultiProcessor;/**< Maximum resident threads per multiprocessor */ + int streamPrioritiesSupported; /**< Device supports stream priorities */ + int globalL1CacheSupported; /**< Device supports caching globals in L1 */ + int localL1CacheSupported; /**< Device supports caching locals in L1 */ + size_t sharedMemPerMultiprocessor; /**< Shared memory available per multiprocessor in bytes */ + int regsPerMultiprocessor; /**< 32-bit registers available per multiprocessor */ + int managedMemory; /**< Device supports allocating managed memory on this system */ + int isMultiGpuBoard; /**< Device is on a multi-GPU board */ + int multiGpuBoardGroupID; /**< Unique identifier for a group of devices on the same multi-GPU board */ + int hostNativeAtomicSupported; /**< Link between the device and the host supports native atomic operations */ + int singleToDoublePrecisionPerfRatio; /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ + int pageableMemoryAccess; /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */ + int concurrentManagedAccess; /**< Device can coherently access managed memory concurrently with the CPU */ + int computePreemptionSupported; /**< Device supports Compute Preemption */ + int 
canUseHostPointerForRegisteredMem; /**< Device can access host registered memory at the same virtual address as the CPU */ + int cooperativeLaunch; /**< Device supports launching cooperative kernels via ::cudaLaunchCooperativeKernel */ + int cooperativeMultiDeviceLaunch; /**< Deprecated, cudaLaunchCooperativeKernelMultiDevice is deprecated. */ + size_t sharedMemPerBlockOptin; /**< Per device maximum shared memory per block usable by special opt in */ + int pageableMemoryAccessUsesHostPageTables; /**< Device accesses pageable memory via the host's page tables */ + int directManagedMemAccessFromHost; /**< Host can directly access managed memory on the device without migration. */ + int maxBlocksPerMultiProcessor; /**< Maximum number of resident blocks per multiprocessor */ + int accessPolicyMaxWindowSize; /**< The maximum value of ::cudaAccessPolicyWindow::num_bytes. */ + size_t reservedSharedMemPerBlock; /**< Shared memory reserved by CUDA driver per block in bytes */ + } cudaDeviceProp_t; + typedef struct cudart_handle { void *handle; uint16_t verbose; @@ -38,23 +130,17 @@ typedef struct cudart_handle { cudartReturn_t (*cudaGetDeviceCount)(int *); cudartReturn_t (*cudaDeviceGetAttribute)(int* value, cudartDeviceAttr_t attr, int device); cudartReturn_t (*cudaDriverGetVersion) (int *driverVersion); + cudartReturn_t (*cudaGetDeviceProperties) (cudaDeviceProp_t* prop, int device); } cudart_handle_t; typedef struct cudart_init_resp { char *err; // If err is non-null handle is invalid cudart_handle_t ch; + int num_devices; } cudart_init_resp_t; -typedef struct cudart_compute_capability { - char *err; - int major; - int minor; -} cudart_compute_capability_t; - - void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp); -void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp); -void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc); +void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp); void cudart_release(cudart_handle_t ch); #endif // __GPU_INFO_CUDART_H__ diff --git a/gpu/gpu_info_nvcuda.c b/gpu/gpu_info_nvcuda.c new file mode 100644 index 00000000..e192d2e6 --- /dev/null +++ b/gpu/gpu_info_nvcuda.c @@ -0,0 +1,203 @@ +#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? 
+ +#include +#include "gpu_info_nvcuda.h" + +void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { + CUresult ret; + resp->err = NULL; + resp->num_devices = 0; + const int buflen = 256; + char buf[buflen + 1]; + int i; + + struct lookup { + char *s; + void **p; + } l[] = { + + {"cuInit", (void *)&resp->ch.cuInit}, + {"cuDriverGetVersion", (void *)&resp->ch.cuDriverGetVersion}, + {"cuDeviceGetCount", (void *)&resp->ch.cuDeviceGetCount}, + {"cuDeviceGet", (void *)&resp->ch.cuDeviceGet}, + {"cuDeviceGetAttribute", (void *)&resp->ch.cuDeviceGetAttribute}, + {"cuDeviceGetUuid", (void *)&resp->ch.cuDeviceGetUuid}, + {"cuCtxCreate_v3", (void *)&resp->ch.cuCtxCreate_v3}, + {"cuMemGetInfo_v2", (void *)&resp->ch.cuMemGetInfo_v2}, + {"cuCtxDestroy", (void *)&resp->ch.cuCtxDestroy}, + {NULL, NULL}, + }; + + resp->ch.handle = LOAD_LIBRARY(nvcuda_lib_path, RTLD_LAZY); + if (!resp->ch.handle) { + char *msg = LOAD_ERR(); + LOG(resp->ch.verbose, "library %s load err: %s\n", nvcuda_lib_path, msg); + snprintf(buf, buflen, + "Unable to load %s library to query for Nvidia GPUs: %s", + nvcuda_lib_path, msg); + free(msg); + resp->err = strdup(buf); + return; + } + + for (i = 0; l[i].s != NULL; i++) { + *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); + if (!*l[i].p) { + char *msg = LOAD_ERR(); + LOG(resp->ch.verbose, "dlerr: %s\n", msg); + UNLOAD_LIBRARY(resp->ch.handle); + resp->ch.handle = NULL; + snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, + msg); + free(msg); + resp->err = strdup(buf); + return; + } + } + + ret = (*resp->ch.cuInit)(0); + if (ret != CUDA_SUCCESS) { + LOG(resp->ch.verbose, "cuInit err: %d\n", ret); + UNLOAD_LIBRARY(resp->ch.handle); + resp->ch.handle = NULL; + if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) { + resp->err = strdup("your nvidia driver is too old or missing. 
If you have a CUDA GPU please upgrade to run ollama"); + return; + } + snprintf(buf, buflen, "nvcuda init failure: %d", ret); + resp->err = strdup(buf); + return; + } + + int version = 0; + nvcudaDriverVersion_t driverVersion; + driverVersion.major = 0; + driverVersion.minor = 0; + + // Report driver version if we're in verbose mode, ignore errors + ret = (*resp->ch.cuDriverGetVersion)(&version); + if (ret != CUDA_SUCCESS) { + LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret); + } else { + driverVersion.major = version / 1000; + driverVersion.minor = (version - (driverVersion.major * 1000)) / 10; + LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor); + } + + ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices); + if (ret != CUDA_SUCCESS) { + LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret); + UNLOAD_LIBRARY(resp->ch.handle); + resp->ch.handle = NULL; + snprintf(buf, buflen, "unable to get device count: %d", ret); + resp->err = strdup(buf); + return; + } +} + +const int buflen = 256; +void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) { + resp->err = NULL; + nvcudaMemory_t memInfo = {0,0}; + CUresult ret; + CUdevice device = -1; + CUcontext ctx = NULL; + char buf[buflen + 1]; + CUuuid uuid = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; + + if (h.handle == NULL) { + resp->err = strdup("nvcuda handle isn't initialized"); + return; + } + + ret = (*h.cuDeviceGet)(&device, i); + if (ret != CUDA_SUCCESS) { + snprintf(buf, buflen, "nvcuda device failed to initialize"); + resp->err = strdup(buf); + return; + } + + resp->major = 0; + resp->minor = 0; + int major = 0; + int minor = 0; + ret = (*h.cuDeviceGetAttribute)(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); + if (ret != CUDA_SUCCESS) { + LOG(h.verbose, "[%d] device major lookup failure: %d\n", i, ret); + } else { + ret = (*h.cuDeviceGetAttribute)(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); + if (ret != CUDA_SUCCESS) { + LOG(h.verbose, "[%d] device minor lookup failure: %d\n", i, ret); + } else { + resp->minor = minor; + resp->major = major; + } + } + + ret = (*h.cuDeviceGetUuid)(&uuid, device); + if (ret != CUDA_SUCCESS) { + LOG(h.verbose, "[%d] device uuid lookup failure: %d\n", i, ret); + snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", i); + } else { + // GPU-d110a105-ac29-1d54-7b49-9c90440f215b + snprintf(&resp->gpu_id[0], GPU_ID_LEN, + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + uuid.bytes[0], + uuid.bytes[1], + uuid.bytes[2], + uuid.bytes[3], + uuid.bytes[4], + uuid.bytes[5], + uuid.bytes[6], + uuid.bytes[7], + uuid.bytes[8], + uuid.bytes[9], + uuid.bytes[10], + uuid.bytes[11], + uuid.bytes[12], + uuid.bytes[13], + uuid.bytes[14], + uuid.bytes[15] + ); + } + + // To get memory we have to set (and release) a context + ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device); + if (ret != CUDA_SUCCESS) { + snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret); + resp->err = strdup(buf); + return; + } + + ret = (*h.cuMemGetInfo_v2)(&memInfo.free, &memInfo.total); + if (ret != CUDA_SUCCESS) { + snprintf(buf, buflen, "nvcuda device memory info lookup failure %d", ret); + resp->err = strdup(buf); + // Best effort on failure... 
+ (*h.cuCtxDestroy)(ctx); + return; + } + + resp->total = memInfo.total; + resp->free = memInfo.free; + + LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024); + LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024); + LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); + + + + ret = (*h.cuCtxDestroy)(ctx); + if (ret != CUDA_SUCCESS) { + LOG(1, "nvcuda failed to release primary device context %d", ret); + } +} + +void nvcuda_release(nvcuda_handle_t h) { + LOG(h.verbose, "releasing nvcuda library\n"); + UNLOAD_LIBRARY(h.handle); + // TODO and other context release logic? + h.handle = NULL; +} + +#endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_info_nvcuda.h b/gpu/gpu_info_nvcuda.h new file mode 100644 index 00000000..c4d94edd --- /dev/null +++ b/gpu/gpu_info_nvcuda.h @@ -0,0 +1,71 @@ +#ifndef __APPLE__ +#ifndef __GPU_INFO_NVCUDA_H__ +#define __GPU_INFO_NVCUDA_H__ +#include "gpu_info.h" + +// Just enough typedef's to dlopen/dlsym for memory information +typedef enum cudaError_enum { + CUDA_SUCCESS = 0, + CUDA_ERROR_INVALID_VALUE = 1, + CUDA_ERROR_MEMORY_ALLOCATION = 2, + CUDA_ERROR_NOT_INITIALIZED = 3, + CUDA_ERROR_INSUFFICIENT_DRIVER = 35, + // Other values omitted for now... +} CUresult; + +typedef enum CUdevice_attribute_enum { + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76, + + // TODO - not yet wired up but may be useful for Jetson or other + // integrated GPU scenarios with shared memory + CU_DEVICE_ATTRIBUTE_INTEGRATED = 18 + +} CUdevice_attribute; + +typedef void *nvcudaDevice_t; // Opaque is sufficient +typedef struct nvcudaMemory_st { + uint64_t total; + uint64_t free; +} nvcudaMemory_t; + +typedef struct nvcudaDriverVersion { + int major; + int minor; +} nvcudaDriverVersion_t; + +typedef struct CUuuid_st { + unsigned char bytes[16]; +} CUuuid; + +typedef int CUdevice; +typedef void* CUcontext; + +typedef struct nvcuda_handle { + void *handle; + uint16_t verbose; + CUresult (*cuInit)(unsigned int Flags); + CUresult (*cuDriverGetVersion)(int *driverVersion); + CUresult (*cuDeviceGetCount)(int *); + CUresult (*cuDeviceGet)(CUdevice* device, int ordinal); + CUresult (*cuDeviceGetAttribute)(int* pi, CUdevice_attribute attrib, CUdevice dev); + CUresult (*cuDeviceGetUuid)(CUuuid* uuid, CUdevice dev); // signature compatible with cuDeviceGetUuid_v2 + + // Context specific aspects + CUresult (*cuCtxCreate_v3)(CUcontext* pctx, void *params, int len, unsigned int flags, CUdevice dev); + CUresult (*cuMemGetInfo_v2)(uint64_t* free, uint64_t* total); + CUresult (*cuCtxDestroy)(CUcontext ctx); +} nvcuda_handle_t; + +typedef struct nvcuda_init_resp { + char *err; // If err is non-null handle is invalid + nvcuda_handle_t ch; + int num_devices; +} nvcuda_init_resp_t; + +void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp); +void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp); +void nvcuda_release(nvcuda_handle_t ch); + +#endif // __GPU_INFO_NVCUDA_H__ +#endif // __APPLE__ diff --git a/gpu/gpu_info_nvml.c b/gpu/gpu_info_nvml.c deleted file mode 100644 index 67c80b0f..00000000 --- a/gpu/gpu_info_nvml.c +++ /dev/null @@ -1,221 +0,0 @@ -#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? 
- -#include - -#include "gpu_info_nvml.h" - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { - nvmlReturn_t ret; - resp->err = NULL; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - struct lookup { - char *s; - void **p; - } l[] = { - {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2}, - {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown}, - {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex}, - {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo}, - {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2}, - {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability}, - {"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion}, - {"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName}, - {"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial}, - {"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion}, - {"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber}, - {"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand}, - {NULL, NULL}, - }; - - resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY); - if (!resp->ch.handle) { - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg); - snprintf(buf, buflen, - "Unable to load %s library to query for Nvidia GPUs: %s", - nvml_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); - if (!l[i].p) { - resp->ch.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->ch.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->ch.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->ch.nvmlInit_v2)(); - if (ret != NVML_SUCCESS) { - LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret); - UNLOAD_LIBRARY(resp->ch.handle); - resp->ch.handle = NULL; - snprintf(buf, buflen, "nvml vram init failure: %d", ret); - resp->err = strdup(buf); - return; - } - - // Report driver version if we're in verbose mode, ignore errors - ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret); - } else { - LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf); - } -} - -void nvml_check_vram(nvml_handle_t h, mem_info_t *resp) { - resp->err = NULL; - nvmlDevice_t device; - nvmlMemory_t memInfo = {0}; - nvmlReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - if (h.handle == NULL) { - resp->err = strdup("nvml handle isn't initialized"); - return; - } - - ret = (*h.nvmlDeviceGetCount_v2)(&resp->count); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } - - resp->total = 0; - resp->free = 0; - for (i = 0; i < resp->count; i++) { - ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - - ret = 
(*h.nvmlDeviceGetMemoryInfo)(device, &memInfo); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - if (h.verbose) { - nvmlBrandType_t brand = 0; - // When in verbose mode, report more information about - // the card we discover, but don't fail on error - ret = (*h.nvmlDeviceGetName)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf); - } - ret = (*h.nvmlDeviceGetBrand)(device, &brand); - if (ret != NVML_SUCCESS) { - LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand); - } - } - - LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total); - LOG(h.verbose, "[%d] CUDA freeMem %ld\n", i, memInfo.free); - - resp->total += memInfo.total; - resp->free += memInfo.free; - } -} - -void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) { - resp->err = NULL; - resp->major = 0; - resp->minor = 0; - nvmlDevice_t device; - int major = 0; - int minor = 0; - nvmlReturn_t ret; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - if (h.handle == NULL) { - resp->err = strdup("nvml handle not initialized"); - return; - } - - unsigned int devices; - ret = (*h.nvmlDeviceGetCount_v2)(&devices); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } - - for (i = 0; i < devices; i++) { - ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - - ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor); - if (ret != NVML_SUCCESS) { - snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); - resp->err = strdup(buf); - return; - } - // Report the lowest major.minor we detect as that limits our compatibility - if (resp->major == 0 || resp->major > major ) { - resp->major = major; - resp->minor = minor; - } else if ( resp->major == major && resp->minor > minor ) { - resp->minor = minor; - } - } -} - -void nvml_release(nvml_handle_t h) { - LOG(h.verbose, "releasing nvml library\n"); - UNLOAD_LIBRARY(h.handle); - h.handle = NULL; -} - -#endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_info_nvml.h b/gpu/gpu_info_nvml.h deleted file mode 100644 index bd1d6001..00000000 --- a/gpu/gpu_info_nvml.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_NVML_H__ -#define __GPU_INFO_NVML_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum nvmlReturn_enum { - NVML_SUCCESS = 0, - // Other values omitted for now... 
-} nvmlReturn_t; -typedef void *nvmlDevice_t; // Opaque is sufficient -typedef struct nvmlMemory_st { - unsigned long long total; - unsigned long long free; - unsigned long long used; -} nvmlMemory_t; - -typedef enum nvmlBrandType_enum -{ - NVML_BRAND_UNKNOWN = 0, -} nvmlBrandType_t; - -typedef struct nvml_handle { - void *handle; - uint16_t verbose; - nvmlReturn_t (*nvmlInit_v2)(void); - nvmlReturn_t (*nvmlShutdown)(void); - nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *); - nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); - nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *); - nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor); - nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length); - nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type); -} nvml_handle_t; - -typedef struct nvml_init_resp { - char *err; // If err is non-null handle is invalid - nvml_handle_t ch; -} nvml_init_resp_t; - -typedef struct nvml_compute_capability { - char *err; - int major; - int minor; -} nvml_compute_capability_t; - -void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp); -void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp); -void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc); -void nvml_release(nvml_handle_t ch); - -#endif // __GPU_INFO_NVML_H__ -#endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_test.go b/gpu/gpu_test.go index f57597b5..a28cbe8c 100644 --- a/gpu/gpu_test.go +++ b/gpu/gpu_test.go @@ -9,23 +9,16 @@ import ( func TestBasicGetGPUInfo(t *testing.T) { info := GetGPUInfo() - assert.Contains(t, "cuda rocm cpu metal", info.Library) - - switch runtime.GOOS { - case "darwin": - // TODO - remove this once MacOS returns some size for CPU - return - case "linux", "windows": - assert.Greater(t, info.TotalMemory, uint64(0)) - assert.Greater(t, info.FreeMemory, uint64(0)) - assert.Greater(t, info.DeviceCount, uint32(0)) - default: - return + assert.Greater(t, len(info), 0) + assert.Contains(t, "cuda rocm cpu metal", info[0].Library) + if info[0].Library != "cpu" { + assert.Greater(t, info[0].TotalMemory, uint64(0)) + assert.Greater(t, info[0].FreeMemory, uint64(0)) } } func TestCPUMemInfo(t *testing.T) { - info, err := getCPUMem() + info, err := GetCPUMem() assert.NoError(t, err) switch runtime.GOOS { case "darwin": diff --git a/gpu/types.go b/gpu/types.go index 7fe6c40c..7a5d5ba7 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -3,7 +3,6 @@ package gpu type memInfo struct { TotalMemory uint64 `json:"total_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"` - DeviceCount uint32 `json:"device_count,omitempty"` } // Beginning of an `ollama info` command @@ -17,11 +16,49 @@ type GpuInfo struct { // MinimumMemory represents the minimum memory required to use the GPU MinimumMemory uint64 `json:"-"` - // TODO add other useful attributes about the card here for discovery information + // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly + DependencyPath string 
`json:"lib_path,omitempty"`
+
+    // GPU information
+    ID    string `json:"gpu_id"`          // string to use for selection of this specific GPU
+    Name  string `json:"name"`            // user friendly name if available
+    Major int    `json:"major,omitempty"` // Major compatibility version (CC or gfx)
+    Minor int    `json:"minor,omitempty"` // Minor compatibility version (CC or gfx)
+    Patch int    `json:"patch,omitempty"` // Patch compatibility only matters on AMD
+
+    // TODO other performance capability info to help in scheduling decisions
 }
 
-type Version struct {
-    Major uint
-    Minor uint
-    Patch uint
+type GpuInfoList []GpuInfo
+
+// Split up the set of GpuInfos by Library and variant
+func (l GpuInfoList) ByLibrary() []GpuInfoList {
+    resp := []GpuInfoList{}
+    libs := []string{}
+    for _, info := range l {
+        found := false
+        requested := info.Library
+        if info.Variant != "" {
+            requested += "_" + info.Variant
+        }
+        for i, lib := range libs {
+            if lib == requested {
+                resp[i] = append(resp[i], info)
+                found = true
+                break
+            }
+        }
+        if !found {
+            libs = append(libs, requested)
+            resp = append(resp, []GpuInfo{info})
+        }
+    }
+    return resp
 }
+
+// Sort by Free Space
+type ByFreeMemory []GpuInfo
+
+func (a ByFreeMemory) Len() int           { return len(a) }
+func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
diff --git a/integration/basic_test.go b/integration/basic_test.go
index 40bde03c..6e632a1c 100644
--- a/integration/basic_test.go
+++ b/integration/basic_test.go
@@ -4,11 +4,14 @@ package integration
 
 import (
     "context"
-    "net/http"
+    "log/slog"
+    "os"
+    "runtime"
     "testing"
     "time"
 
     "github.com/ollama/ollama/api"
+    "github.com/stretchr/testify/require"
 )
 
 func TestOrcaMiniBlueSky(t *testing.T) {
@@ -24,5 +27,44 @@ func TestOrcaMiniBlueSky(t *testing.T) {
             "seed":        123,
         },
     }
-    GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
+    GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
+}
+
+func TestUnicodeModelDir(t *testing.T) {
+    // This is only useful for Windows with utf-16 characters, so skip this test for other platforms
+    if runtime.GOOS != "windows" {
+        t.Skip("Unicode test only applicable to windows")
+    }
+    // Only works for local testing
+    if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
+        t.Skip("TestUnicodeModelDir only works for local testing, skipping")
+    }
+
+    modelDir, err := os.MkdirTemp("", "ollama_埃")
+    require.NoError(t, err)
+    defer os.RemoveAll(modelDir)
+    slog.Info("unicode", "OLLAMA_MODELS", modelDir)
+
+    oldModelsDir := os.Getenv("OLLAMA_MODELS")
+    if oldModelsDir == "" {
+        defer os.Unsetenv("OLLAMA_MODELS")
+    } else {
+        defer os.Setenv("OLLAMA_MODELS", oldModelsDir)
+    }
+    err = os.Setenv("OLLAMA_MODELS", modelDir)
+    require.NoError(t, err)
+
+    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+    defer cancel()
+
+    req := api.GenerateRequest{
+        Model:  "orca-mini",
+        Prompt: "why is the sky blue?",
+        Stream: &stream,
+        Options: map[string]interface{}{
+            "temperature": 0,
+            "seed":        123,
+        },
+    }
+    GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
+}
diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go
new file mode 100644
index 00000000..110301ab
--- /dev/null
+++ b/integration/concurrency_test.go
@@ -0,0 +1,225 @@
+//go:build integration
+
+package integration
+
+import (
+    "context"
+    "log/slog"
+    "os"
+    "strconv"
+    "sync"
+    "testing"
+    "time"
+
+    "github.com/ollama/ollama/api"
+    "github.com/stretchr/testify/require"
+)
+
+func TestMultiModelConcurrency(t *testing.T) {
+    var (
+        req = [2]api.GenerateRequest{
+            {
+                Model:  "orca-mini",
+                Prompt: "why is the ocean blue?",
+                Stream: &stream,
+                Options: map[string]interface{}{
+                    "seed":        42,
+                    "temperature": 0.0,
+                },
+            }, {
+                Model:  "tinydolphin",
+                Prompt: "what is the origin of the us thanksgiving holiday?",
+                Stream: &stream,
+                Options: map[string]interface{}{
+                    "seed":        42,
+                    "temperature": 0.0,
+                },
+            },
+        }
+        resp = [2][]string{
+            []string{"sunlight"},
+            []string{"england", "english", "massachusetts", "pilgrims"},
+        }
+    )
+    var wg sync.WaitGroup
+    wg.Add(len(req))
+    ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
+    defer cancel()
+    for i := 0; i < len(req); i++ {
+        go func(i int) {
+            defer wg.Done()
+            GenerateTestHelper(ctx, t, req[i], resp[i])
+        }(i)
+    }
+    wg.Wait()
+}
+
+func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
+    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    req, resp := GenerateRequests()
+    // Get the server running (if applicable) and warm the model up with a single initial request
+    DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
+
+    var wg sync.WaitGroup
+    wg.Add(len(req))
+    for i := 0; i < len(req); i++ {
+        go func(i int) {
+            defer wg.Done()
+            for j := 0; j < 5; j++ {
+                slog.Info("Starting", "req", i, "iter", j)
+                // On slower GPUs it can take a while to process the 4 concurrent requests
+                // so we allow a much longer initial timeout
+                DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+            }
+        }(i)
+    }
+    wg.Wait()
+}
+
+// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
+func TestMultiModelStress(t *testing.T) {
+    vram := os.Getenv("OLLAMA_MAX_VRAM")
+    if vram == "" {
+        t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
+    }
+    max, err := strconv.ParseUint(vram, 10, 64)
+    require.NoError(t, err)
+    const MB = uint64(1024 * 1024)
+    type model struct {
+        name string
+        size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
+    }
+
+    smallModels := []model{
+        {
+            name: "orca-mini",
+            size: 2992 * MB,
+        },
+        {
+            name: "phi",
+            size: 2616 * MB,
+        },
+        {
+            name: "gemma:2b",
+            size: 2364 * MB,
+        },
+        {
+            name: "stable-code:3b",
+            size: 2608 * MB,
+        },
+        {
+            name: "starcoder2:3b",
+            size: 2166 * MB,
+        },
+    }
+    mediumModels := []model{
+        {
+            name: "llama2",
+            size: 5118 * MB,
+        },
+        {
+            name: "mistral",
+            size: 4620 * MB,
+        },
+        {
+            name: "orca-mini:7b",
+            size: 5118 * MB,
+        },
+        {
+            name: "dolphin-mistral",
+            size: 4620 * MB,
+        },
+        {
+            name: "gemma:7b",
+            size: 5000 * MB,
+        },
+        // TODO - uncomment this once #3565 is merged and this is rebased on it
+        // {
+        //     name: "codellama:7b",
+        //     size: 5118 * MB,
+        // },
+    }
+
+    // These seem to be too slow to be useful...
+    // largeModels := []model{
+    //     {
+    //         name: "llama2:13b",
+    //         size: 7400 * MB,
+    //     },
+    //     {
+    //         name: "codellama:13b",
+    //         size: 7400 * MB,
+    //     },
+    //     {
+    //         name: "orca-mini:13b",
+    //         size: 7400 * MB,
+    //     },
+    //     {
+    //         name: "gemma:7b",
+    //         size: 5000 * MB,
+    //     },
+    //     {
+    //         name: "starcoder2:15b",
+    //         size: 9100 * MB,
+    //     },
+    // }
+
+    var chosenModels []model
+    switch {
+    case max < 10000*MB:
+        slog.Info("selecting small models")
+        chosenModels = smallModels
+    // case max < 30000*MB:
+    default:
+        slog.Info("selecting medium models")
+        chosenModels = mediumModels
+        // default:
+        //     slog.Info("selecting large models")
+        //     chosenModels = largeModels
+    }
+
+    req, resp := GenerateRequests()
+
+    for i := range req {
+        if i >= len(chosenModels) {
+            break
+        }
+        req[i].Model = chosenModels[i].name
+    }
+
+    ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    // Make sure all the models are pulled before we get started
+    for _, r := range req {
+        require.NoError(t, PullIfMissing(ctx, client, r.Model))
+    }
+
+    var wg sync.WaitGroup
+    consumed := uint64(256 * MB) // Assume some baseline usage
+    for i := 0; i < len(req); i++ {
+        // Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
+        if i > 1 && consumed > max {
+            slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+            break
+        }
+        consumed += chosenModels[i].size
+        slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
+
+        wg.Add(1)
+        go func(i int) {
+            defer wg.Done()
+            for j := 0; j < 3; j++ {
+                slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
+                DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
+            }
+        }(i)
+    }
+    wg.Wait()
+}
diff --git a/integration/context_test.go b/integration/context_test.go
index 80ea540b..08033125 100644
--- a/integration/context_test.go
+++ b/integration/context_test.go
@@ -4,7 +4,6 @@ package integration
 
 import (
     "context"
-    "net/http"
    "testing"
     "time"
 
@@ -25,5 +24,5 @@ func TestContextExhaustion(t *testing.T) {
             "num_ctx": 128,
         },
     }
-    GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
+    GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
 }
diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go
index 94082d6e..77319aef 100644
--- a/integration/llm_image_test.go
+++ b/integration/llm_image_test.go
@@ -5,7 +5,6 @@ package integration
 import (
     "context"
     "encoding/base64"
-    "net/http"
     "testing"
     "time"
 
@@ -29,10 +28,11 @@ func TestIntegrationMultimodal(t *testing.T) {
         },
     }
 
-    resp := "the ollamas"
+    // Note: sometimes it returns "the ollamas", sometimes "the ollams"
+    resp := "the ollam"
     ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
     defer cancel()
-    GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
+    GenerateTestHelper(ctx, t, req, []string{resp})
 }
 
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
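Note: the stress test above admits models against an estimated VRAM budget: it always launches at least two, then stops once the running total of approximate footprints exceeds OLLAMA_MAX_VRAM. A self-contained sketch of that accounting, combined with the free-memory ordering added to gpu/types.go earlier in this patch; the gpu type, model sizes, and all values below are illustrative stand-ins, not part of the patch:

package main

import (
	"fmt"
	"sort"
)

// gpu mirrors only the field the ordering needs; illustrative stand-in.
type gpu struct {
	ID         string
	FreeMemory uint64
}

type byFreeMemory []gpu

func (a byFreeMemory) Len() int           { return len(a) }
func (a byFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a byFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }

func main() {
	const MB = uint64(1024 * 1024)

	// Ascending sort, reversed, puts the roomiest device first.
	gpus := []gpu{{"0", 4096 * MB}, {"1", 8192 * MB}}
	sort.Sort(sort.Reverse(byFreeMemory(gpus)))
	budget := gpus[0].FreeMemory

	// Budget accounting: always admit at least two models, then stop
	// once the estimated footprints exceed the budget.
	sizes := []uint64{2992 * MB, 2616 * MB, 2364 * MB, 2608 * MB, 2166 * MB}
	consumed := 256 * MB // assume some baseline usage
	admitted := 0
	for i, size := range sizes {
		if i > 1 && consumed > budget {
			break
		}
		consumed += size
		admitted++
	}
	fmt.Printf("gpu %s admitted %d models (~%d MB)\n", gpus[0].ID, admitted, consumed/MB)
}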
diff --git a/integration/llm_test.go b/integration/llm_test.go
index bcc169d6..4952b072 100644
--- a/integration/llm_test.go
+++ b/integration/llm_test.go
@@ -4,8 +4,6 @@ package integration
 
 import (
     "context"
-    "net/http"
-    "sync"
     "testing"
     "time"
 
@@ -45,25 +43,5 @@ var (
 
 func TestIntegrationSimpleOrcaMini(t *testing.T) {
     ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
     defer cancel()
-    GenerateTestHelper(ctx, t, &http.Client{}, req[0], resp[0])
+    GenerateTestHelper(ctx, t, req[0], resp[0])
 }
-
-// TODO
-// The server always loads a new runner and closes the old one, which forces serial execution
-// At present this test case fails with concurrency problems. Eventually we should try to
-// get true concurrency working with n_parallel support in the backend
-func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
-    var wg sync.WaitGroup
-    wg.Add(len(req))
-    ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
-    defer cancel()
-    for i := 0; i < len(req); i++ {
-        go func(i int) {
-            defer wg.Done()
-            GenerateTestHelper(ctx, t, &http.Client{}, req[i], resp[i])
-        }(i)
-    }
-    wg.Wait()
-}
-
-// TODO - create a parallel test with 2 different models once we support concurrency
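Note: the new max_queue_test.go below, and DoGenerate further down in utils_test.go, share one idiom: a time.Timer is reset on every streamed callback, so a stream that stops producing output trips the timer instead of hanging the test. A minimal sketch of the pattern; the names and durations are illustrative, not from this patch:

package main

import (
	"fmt"
	"time"
)

// watchStream fails if produce leaves more than stall between values.
func watchStream(produce func(chan<- string), stall time.Duration) error {
	ch := make(chan string, 8)
	done := make(chan struct{})
	go func() { produce(ch); close(done) }()

	stallTimer := time.NewTimer(stall)
	for {
		select {
		case <-ch:
			// Reset returns false if the timer already fired: treat as a stall.
			if !stallTimer.Reset(stall) {
				return fmt.Errorf("stall detected while draining stream")
			}
		case <-stallTimer.C:
			return fmt.Errorf("stream stalled: no data within %s", stall)
		case <-done:
			return nil
		}
	}
}

func main() {
	err := watchStream(func(ch chan<- string) {
		ch <- "token"
		time.Sleep(50 * time.Millisecond) // pause longer than the stall window
		ch <- "late token"
	}, 10*time.Millisecond)
	fmt.Println(err)
}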
diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go
new file mode 100644
index 00000000..43b15c6c
--- /dev/null
+++ b/integration/max_queue_test.go
@@ -0,0 +1,117 @@
+//go:build integration
+
+package integration
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "log/slog"
+    "os"
+    "strconv"
+    "strings"
+    "sync"
+    "testing"
+    "time"
+
+    "github.com/ollama/ollama/api"
+    "github.com/stretchr/testify/require"
+)
+
+func TestMaxQueue(t *testing.T) {
+    // Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless you're on a GPU
+    // Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
+    threadCount := 32
+    mq := os.Getenv("OLLAMA_MAX_QUEUE")
+    if mq != "" {
+        var err error
+        threadCount, err = strconv.Atoi(mq)
+        require.NoError(t, err)
+    } else {
+        os.Setenv("OLLAMA_MAX_QUEUE", fmt.Sprintf("%d", threadCount))
+    }
+
+    req := api.GenerateRequest{
+        Model:  "orca-mini",
+        Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
+        Options: map[string]interface{}{
+            "seed":        42,
+            "temperature": 0.0,
+        },
+    }
+    resp := []string{"explore", "discover", "ocean"}
+
+    // CPU mode takes much longer at the limit with a large queue setting
+    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+    defer cancel()
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+
+    require.NoError(t, PullIfMissing(ctx, client, req.Model))
+
+    // Context for the worker threads so we can shut them down
+    // embedCtx, embedCancel := context.WithCancel(ctx)
+    embedCtx := ctx
+
+    var genwg sync.WaitGroup
+    genwg.Add(1)
+    go func() {
+        defer genwg.Done()
+        slog.Info("Starting generate request")
+        DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)
+        slog.Info("generate completed")
+    }()
+
+    // Give the generate a chance to get started before we start hammering on embed requests
+    time.Sleep(5 * time.Millisecond)
+
+    threadCount += 10 // Add a few extra to ensure we push the queue past its limit
+    busyCount := 0
+    resetByPeerCount := 0
+    canceledCount := 0
+    successCount := 0
+    counterMu := sync.Mutex{}
+    var embedwg sync.WaitGroup
+    for i := 0; i < threadCount; i++ {
+        embedwg.Add(1)
+        go func(i int) {
+            defer embedwg.Done()
+            slog.Info("embed started", "id", i)
+            embedReq := api.EmbeddingRequest{
+                Model:   req.Model,
+                Prompt:  req.Prompt,
+                Options: req.Options,
+            }
+            // Fresh client for every request
+            client, _ = GetTestEndpoint()
+
+            resp, genErr := client.Embeddings(embedCtx, &embedReq)
+            counterMu.Lock()
+            defer counterMu.Unlock()
+            switch {
+            case genErr == nil:
+                successCount++
+                require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
+            case errors.Is(genErr, context.Canceled):
+                canceledCount++
+            case strings.Contains(genErr.Error(), "busy"):
+                busyCount++
+            case strings.Contains(genErr.Error(), "connection reset by peer"):
+                resetByPeerCount++
+            default:
+                require.NoError(t, genErr, "%d request failed", i)
+            }
+
+            slog.Info("embed finished", "id", i)
+        }(i)
+    }
+    genwg.Wait()
+    slog.Info("generate done, waiting for embeds")
+    embedwg.Wait()
+
+    require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
+    require.True(t, busyCount > 0, "no requests hit busy error but some should have")
+    require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
+
+    slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
+}
diff --git a/integration/utils_test.go b/integration/utils_test.go
index 0f712271..e133e76d 100644
--- a/integration/utils_test.go
+++ b/integration/utils_test.go
@@ -5,13 +5,14 @@ package integration
 import (
     "bytes"
     "context"
-    "encoding/json"
+    "errors"
     "fmt"
     "io"
     "log/slog"
     "math/rand"
     "net"
     "net/http"
+    "net/url"
     "os"
     "path/filepath"
     "runtime"
@@ -23,9 +24,13 @@ import (
 
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/app/lifecycle"
-    "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
 )
 
+func Init() {
+    lifecycle.InitLogging()
+}
+
 func FindPort() string {
     port := 0
     if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
@@ -41,7 +46,7 @@ func FindPort() string {
     return strconv.Itoa(port)
 }
 
-func GetTestEndpoint() (string, string) {
+func GetTestEndpoint() (*api.Client, string) {
     defaultPort := "11434"
     ollamaHost := os.Getenv("OLLAMA_HOST")
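Note: the next hunk changes GetTestEndpoint to hand back a ready api.Client instead of raw scheme/host strings. For reference, this is the basic client construction against a running server; the endpoint value is illustrative, while api.NewClient and Client.Heartbeat are existing ollama/api functions:

package main

import (
	"context"
	"log"
	"net/http"
	"net/url"

	"github.com/ollama/ollama/api"
)

func main() {
	client := api.NewClient(
		&url.URL{Scheme: "http", Host: "127.0.0.1:11434"}, // illustrative endpoint
		http.DefaultClient,
	)
	// Heartbeat is a cheap liveness probe against the server.
	if err := client.Heartbeat(context.Background()); err != nil {
		log.Fatalf("server not responding: %s", err)
	}
	log.Println("server is up")
}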
@@ -67,16 +72,20 @@ func GetTestEndpoint() (string, string) {
         port = FindPort()
     }
 
-    url := fmt.Sprintf("%s:%s", host, port)
-    slog.Info("server connection", "url", url)
-    return scheme, url
+    slog.Info("server connection", "host", host, "port", port)
+
+    return api.NewClient(
+        &url.URL{
+            Scheme: scheme,
+            Host:   net.JoinHostPort(host, port),
+        },
+        http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
 }
 
-// TODO make fanicier, grab logs, etc.
 var serverMutex sync.Mutex
 var serverReady bool
 
-func StartServer(ctx context.Context, ollamaHost string) error {
+func startServer(ctx context.Context, ollamaHost string) error {
     // Make sure the server has been built
     CLIName, err := filepath.Abs("../ollama")
     if err != nil {
@@ -98,7 +107,7 @@
     if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
         slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
         os.Setenv("OLLAMA_HOST", ollamaHost)
     }
 
     slog.Info("starting server", "url", ollamaHost)
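Note: the PullIfMissing rewrite in the following hunk drops the hand-rolled HTTP requests in favor of the api client's streaming Pull. A sketch of the call shape; the model name and the progress formatting are illustrative:

package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"net/url"

	"github.com/ollama/ollama/api"
)

func main() {
	client := api.NewClient(&url.URL{Scheme: "http", Host: "127.0.0.1:11434"}, http.DefaultClient)

	req := &api.PullRequest{Name: "orca-mini"} // illustrative model
	// The callback runs for every progress update the server streams back;
	// returning a non-nil error aborts the pull.
	err := client.Pull(context.Background(), req, func(p api.ProgressResponse) error {
		fmt.Printf("\r%s %d/%d", p.Status, p.Completed, p.Total)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("\npull complete")
}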
@@ -125,67 +134,76 @@ func StartServer(ctx context.Context, ollamaHost string) error {
     return nil
 }
 
-func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
+func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
     slog.Info("checking status of model", "model", modelName)
     showReq := &api.ShowRequest{Name: modelName}
-    requestJSON, err := json.Marshal(showReq)
-    if err != nil {
-        return err
-    }
-    req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON))
-    if err != nil {
+    showCtx, cancel := context.WithDeadlineCause(
+        ctx,
+        time.Now().Add(5*time.Second),
+        fmt.Errorf("show for existing model %s took too long", modelName),
+    )
+    defer cancel()
+    _, err := client.Show(showCtx, showReq)
+    var statusError api.StatusError
+    switch {
+    case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
+        break
+    case err != nil:
         return err
-    }
-
-    // Make the request with the HTTP client
-    response, err := client.Do(req.WithContext(ctx))
-    if err != nil {
-        return err
-    }
-    defer response.Body.Close()
-    if response.StatusCode == 200 {
+    default:
         slog.Info("model already present", "model", modelName)
         return nil
     }
-    slog.Info("model missing", "status", response.StatusCode)
+    slog.Info("model missing", "model", modelName)
+    stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
+    stallTimer := time.NewTimer(stallDuration)
+    fn := func(resp api.ProgressResponse) error {
+        // fmt.Print(".")
+        if !stallTimer.Reset(stallDuration) {
+            return fmt.Errorf("stall was detected, aborting status reporting")
+        }
+        return nil
+    }
+
+    stream := true
     pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
-    requestJSON, err = json.Marshal(pullReq)
-    if err != nil {
-        return err
-    }
-    req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON))
-    if err != nil {
-        return err
-    }
-    slog.Info("pulling", "model", modelName)
+    var pullError error
-    response, err = client.Do(req.WithContext(ctx))
-    if err != nil {
-        return err
+    done := make(chan int)
+    go func() {
+        pullError = client.Pull(ctx, pullReq, fn)
+        done <- 0
+    }()
+
+    select {
+    case <-stallTimer.C:
+        return fmt.Errorf("download stalled")
+    case <-done:
+        return pullError
     }
-    defer response.Body.Close()
-    if response.StatusCode != 200 {
-        return fmt.Errorf("failed to pull model") // TODO more details perhaps
-    }
-    slog.Info("model pulled", "model", modelName)
-    return nil
 }
 
 var serverProcMutex sync.Mutex
 
-func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
-
-    // TODO maybe stuff in an init routine?
-    lifecycle.InitLogging()
-
-    requestJSON, err := json.Marshal(genReq)
-    if err != nil {
-        t.Fatalf("Error serializing request: %v", err)
+// Returns a Client, the testEndpoint, and a cleanup function; fails the test on errors
+// Starts the server if needed
+func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
+    client, testEndpoint := GetTestEndpoint()
+    if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
+        serverProcMutex.Lock()
+        fp, err := os.CreateTemp("", "ollama-server-*.log")
+        if err != nil {
+            t.Fatalf("failed to generate log file: %s", err)
+        }
+        lifecycle.ServerLogFile = fp.Name()
+        fp.Close()
+        require.NoError(t, startServer(ctx, testEndpoint))
     }
-    defer func() {
+
+    return client, testEndpoint, func() {
         if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
             defer serverProcMutex.Unlock()
             if t.Failed() {
@@ -203,63 +221,118 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client,
                 os.Stderr.Write(data)
                 slog.Warn("END OF SERVER")
             }
-            err = os.Remove(lifecycle.ServerLogFile)
+            err := os.Remove(lifecycle.ServerLogFile)
             if err != nil && !os.IsNotExist(err) {
                 slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
             }
         }
-    }()
-    scheme, testEndpoint := GetTestEndpoint()
-
-    if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
-        serverProcMutex.Lock()
-        fp, err := os.CreateTemp("", "ollama-server-*.log")
-        if err != nil {
-            t.Fatalf("failed to generate log file: %s", err)
-        }
-        lifecycle.ServerLogFile = fp.Name()
-        fp.Close()
-        assert.NoError(t, StartServer(ctx, testEndpoint))
     }
-
-    err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model)
-    if err != nil {
-        t.Fatalf("Error pulling model: %v", err)
-    }
-
-    // Make the request and get the response
-    req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
-    if err != nil {
-        t.Fatalf("Error creating request: %v", err)
-    }
-
-    // Set the content type for the request
-    req.Header.Set("Content-Type", "application/json")
-
-    // Make the request with the HTTP client
-    response, err := client.Do(req.WithContext(ctx))
-    if err != nil {
-        t.Fatalf("Error making request: %v", err)
-    }
-    defer response.Body.Close()
-    body, err := io.ReadAll(response.Body)
-    assert.NoError(t, err)
-    assert.Equal(t, response.StatusCode, 200, string(body))
-
-    // Verify the response is valid JSON
-    var payload api.GenerateResponse
-    err = json.Unmarshal(body, &payload)
-    if err != nil {
-        assert.NoError(t, err, body)
-    }
-
-    // Verify the response contains the expected data
-    atLeastOne := false
-    for _, resp := range anyResp {
-        if strings.Contains(strings.ToLower(payload.Response), resp) {
-            atLeastOne = true
-            break
-        }
-    }
-    assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response)
+}
+
+func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
+    client, _, cleanup := InitServerConnection(ctx, t)
+    defer cleanup()
+    require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
+    DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
+}
+
+func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout
time.Duration) { + stallTimer := time.NewTimer(initialTimeout) + var buf bytes.Buffer + fn := func(response api.GenerateResponse) error { + // fmt.Print(".") + buf.Write([]byte(response.Response)) + if !stallTimer.Reset(streamTimeout) { + return fmt.Errorf("stall was detected while streaming response, aborting") + } + return nil + } + + stream := true + genReq.Stream = &stream + done := make(chan int) + var genErr error + go func() { + genErr = client.Generate(ctx, &genReq, fn) + done <- 0 + }() + + select { + case <-stallTimer.C: + if buf.Len() == 0 { + t.Errorf("generate never started. Timed out after :%s", initialTimeout.String()) + } else { + t.Errorf("generate stalled. Response so far:%s", buf.String()) + } + case <-done: + require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt) + // Verify the response contains the expected data + response := buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + require.True(t, atLeastOne, "none of %v found in %s", anyResp, response) + slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) + case <-ctx.Done(): + t.Error("outer test context done while waiting for generate") + } +} + +// Generate a set of requests +// By default each request uses orca-mini as the model +func GenerateRequests() ([]api.GenerateRequest, [][]string) { + return []api.GenerateRequest{ + { + Model: "orca-mini", + Prompt: "why is the ocean blue?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "why is the color of dirt brown?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "what is the origin of the us thanksgiving holiday?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "what is the origin of independence day?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, { + Model: "orca-mini", + Prompt: "what is the composition of air?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + }, + }, + [][]string{ + []string{"sunlight"}, + []string{"soil", "organic", "earth", "black", "tan"}, + []string{"england", "english", "massachusetts", "pilgrims"}, + []string{"fourth", "july", "declaration", "independence"}, + []string{"nitrogen", "oxygen", "carbon", "dioxide"}, + } } diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp index 96df9f4b..41455c65 100644 --- a/llm/ext_server/server.cpp +++ b/llm/ext_server/server.cpp @@ -1032,7 +1032,7 @@ struct llama_server_context slot.has_next_token = false; } - if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model)) + if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok)) { slot.stopped_eos = true; slot.has_next_token = false; @@ -1144,12 +1144,15 @@ struct llama_server_context res.result_json = json { - {"content", tkn.text_to_send}, {"stop", false}, {"slot_id", slot.id}, {"multimodal", multimodal} }; + if (!llama_token_is_eog(model, tkn.tok)) { + res.result_json["content"] = tkn.text_to_send; + } + if (slot.sparams.n_probs > 0) { std::vector probs_output = {}; @@ -1183,8 +1186,6 @@ struct llama_server_context {"model", params.model_alias}, 
{"tokens_predicted", slot.n_decoded}, {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, {"truncated", slot.truncated}, {"stopped_eos", slot.stopped_eos}, {"stopped_word", slot.stopped_word}, @@ -2644,18 +2645,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, if (strncmp(sep, "int:", 4) == 0) { sep += 4; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.int_value = std::atol(sep); + kvo.val_i64 = std::atol(sep); } else if (strncmp(sep, "float:", 6) == 0) { sep += 6; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.float_value = std::atof(sep); + kvo.val_f64 = std::atof(sep); } else if (strncmp(sep, "bool:", 5) == 0) { sep += 5; kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; if (std::strcmp(sep, "true") == 0) { - kvo.bool_value = true; + kvo.val_bool = true; } else if (std::strcmp(sep, "false") == 0) { - kvo.bool_value = false; + kvo.val_bool = false; } else { fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]); invalid_param = true; diff --git a/llm/filetype.go b/llm/filetype.go new file mode 100644 index 00000000..e5e9410d --- /dev/null +++ b/llm/filetype.go @@ -0,0 +1,140 @@ +package llm + +import "fmt" + +type fileType uint32 + +const ( + fileTypeF32 fileType = iota + fileTypeF16 + fileTypeQ4_0 + fileTypeQ4_1 + fileTypeQ4_1_F16 + fileTypeQ4_2 // unused + fileTypeQ4_3 // unused + fileTypeQ8_0 + fileTypeQ5_0 + fileTypeQ5_1 + fileTypeQ2_K + fileTypeQ3_K_S + fileTypeQ3_K_M + fileTypeQ3_K_L + fileTypeQ4_K_S + fileTypeQ4_K_M + fileTypeQ5_K_S + fileTypeQ5_K_M + fileTypeQ6_K + fileTypeIQ2_XXS + fileTypeIQ2_XS + fileTypeQ2_K_S + fileTypeQ3_K_XS + fileTypeIQ3_XXS + + fileTypeUnknown +) + +func ParseFileType(s string) (fileType, error) { + switch s { + case "F32": + return fileTypeF32, nil + case "F16": + return fileTypeF16, nil + case "Q4_0": + return fileTypeQ4_0, nil + case "Q4_1": + return fileTypeQ4_1, nil + case "Q4_1_F16": + return fileTypeQ4_1_F16, nil + case "Q8_0": + return fileTypeQ8_0, nil + case "Q5_0": + return fileTypeQ5_0, nil + case "Q5_1": + return fileTypeQ5_1, nil + case "Q2_K": + return fileTypeQ2_K, nil + case "Q3_K_S": + return fileTypeQ3_K_S, nil + case "Q3_K_M": + return fileTypeQ3_K_M, nil + case "Q3_K_L": + return fileTypeQ3_K_L, nil + case "Q4_K_S": + return fileTypeQ4_K_S, nil + case "Q4_K_M": + return fileTypeQ4_K_M, nil + case "Q5_K_S": + return fileTypeQ5_K_S, nil + case "Q5_K_M": + return fileTypeQ5_K_M, nil + case "Q6_K": + return fileTypeQ6_K, nil + case "IQ2_XXS": + return fileTypeIQ2_XXS, nil + case "IQ2_XS": + return fileTypeIQ2_XS, nil + case "Q2_K_S": + return fileTypeQ2_K_S, nil + case "Q3_K_XS": + return fileTypeQ3_K_XS, nil + case "IQ3_XXS": + return fileTypeIQ3_XXS, nil + default: + return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s) + } +} + +func (t fileType) String() string { + switch t { + case fileTypeF32: + return "F32" + case fileTypeF16: + return "F16" + case fileTypeQ4_0: + return "Q4_0" + case fileTypeQ4_1: + return "Q4_1" + case fileTypeQ4_1_F16: + return "Q4_1_F16" + case fileTypeQ8_0: + return "Q8_0" + case fileTypeQ5_0: + return "Q5_0" + case fileTypeQ5_1: + return "Q5_1" + case fileTypeQ2_K: + return "Q2_K" + case fileTypeQ3_K_S: + return "Q3_K_S" + case fileTypeQ3_K_M: + return "Q3_K_M" + case fileTypeQ3_K_L: + return "Q3_K_L" + case fileTypeQ4_K_S: + return "Q4_K_S" + case fileTypeQ4_K_M: + return "Q4_K_M" + case fileTypeQ5_K_S: + return "Q5_K_S" + case fileTypeQ5_K_M: + return "Q5_K_M" + case fileTypeQ6_K: 
+ return "Q6_K" + case fileTypeIQ2_XXS: + return "IQ2_XXS" + case fileTypeIQ2_XS: + return "IQ2_XS" + case fileTypeQ2_K_S: + return "Q2_K_S" + case fileTypeQ3_K_XS: + return "Q3_K_XS" + case fileTypeIQ3_XXS: + return "IQ3_XXS" + default: + return "unknown" + } +} + +func (t fileType) Value() uint32 { + return uint32(t) +} diff --git a/llm/generate/gen_common.sh b/llm/generate/gen_common.sh index 16ff710a..da1b0688 100644 --- a/llm/generate/gen_common.sh +++ b/llm/generate/gen_common.sh @@ -21,7 +21,7 @@ init_vars() { # TODO - add additional optimization flags... CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}" fi - case $(uname -s) in + case $(uname -s) in "Darwin") LIB_EXT="dylib" WHOLE_ARCHIVE="-Wl,-force_load" diff --git a/llm/generate/gen_linux.sh b/llm/generate/gen_linux.sh index fd4a6bc0..63668bd2 100755 --- a/llm/generate/gen_linux.sh +++ b/llm/generate/gen_linux.sh @@ -57,21 +57,21 @@ init_vars git_module_setup apply_patches +init_vars +if [ -z "${OLLAMA_SKIP_STATIC_GENERATE}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then + # Builds by default, allows skipping, forces build if OLLAMA_CPU_TARGET="static" + # Enables optimized Dockerfile builds using a blanket skip and targeted overrides + # Static build for linking into the Go binary + init_vars + CMAKE_TARGETS="--target llama --target ggml" + CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" + BUILD_DIR="../build/linux/${ARCH}_static" + echo "Building static library" + build +fi init_vars if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then - - if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then - # Static build for linking into the Go binary - init_vars - CMAKE_TARGETS="--target llama --target ggml" - CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" - BUILD_DIR="../build/linux/${ARCH}_static" - echo "Building static library" - build - fi - - # Users building from source can tune the exact flags we pass to cmake for configuring # llama.cpp, and we'll build only 1 CPU variant in that case as the default. if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then @@ -165,14 +165,22 @@ if [ -d "${CUDA_LIB_DIR}" ]; then fi if [ "${ARCH}" == "arm64" ]; then echo "ARM CPU detected - disabling unsupported AVX instructions" - + # ARM-based CPUs such as M1 and Tegra do not support AVX extensions. # - # CUDA compute < 6.0 lacks proper FP16 support on ARM. - # Disabling has minimal performance effect while maintaining compatibility. + # CUDA compute < 6.0 lacks proper FP16 support on ARM. + # Disabling has minimal performance effect while maintaining compatibility. 
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off" fi - CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}" + # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp + if [ -n "${OLLAMA_CUSTOM_CUDA_DEFS}" ]; then + echo "OLLAMA_CUSTOM_CUDA_DEFS=\"${OLLAMA_CUSTOM_CUDA_DEFS}\"" + CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}" + echo "Building custom CUDA GPU" + else + CMAKE_CUDA_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}" + fi + CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}" BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}" EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda" build @@ -217,6 +225,12 @@ if [ -d "${ROCM_PATH}" ]; then fi init_vars CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)" + # Users building from source can tune the exact flags we pass to cmake for configuring llama.cpp + if [ -n "${OLLAMA_CUSTOM_ROCM_DEFS}" ]; then + echo "OLLAMA_CUSTOM_ROCM_DEFS=\"${OLLAMA_CUSTOM_ROCM_DEFS}\"" + CMAKE_DEFS="${CMAKE_DEFS} ${OLLAMA_CUSTOM_ROCM_DEFS}" + echo "Building custom ROCM GPU" + fi BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}" EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu" build diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1 index 0d2ae57f..9bdfb9d3 100644 --- a/llm/generate/gen_windows.ps1 +++ b/llm/generate/gen_windows.ps1 @@ -26,15 +26,25 @@ function amdGPUs { $GPU_LIST -join ';' } + function init_vars { - $script:SRC_DIR = $(resolve-path "..\..\") - $script:llamacppDir = "../llama.cpp" + if (!$script:SRC_DIR) { + $script:SRC_DIR = $(resolve-path "..\..\") + } + if (!$script:llamacppDir) { + $script:llamacppDir = "../llama.cpp" + } + if (!$script:cmakeTargets) { + $script:cmakeTargets = @("ollama_llama_server") + } $script:cmakeDefs = @( "-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off" ) - $script:cmakeTargets = @("ollama_llama_server") - $script:ARCH = "amd64" # arm not yet supported. + $script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on") + $script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower() + $script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners" + md "$script:DIST_BASE" -ea 0 > $null if ($env:CGO_CFLAGS -contains "-g") { $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo") $script:config = "RelWithDebInfo" @@ -55,7 +65,6 @@ function init_vars { } else { $script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR } - $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) { $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80" @@ -134,21 +143,18 @@ function sign { } } -function compress { - if ($script:GZIP -eq $null) { - write-host "gzip not installed, not compressing files" - return - } - write-host "Compressing binaries..." 
+function install { + write-host "Installing binaries to dist dir ${script:distDir}" + mkdir ${script:distDir} -ErrorAction SilentlyContinue $binaries = dir "${script:buildDir}/bin/*.exe" foreach ($file in $binaries) { - & "$script:GZIP" --best -f $file + copy-item -Path $file -Destination ${script:distDir} -Force } - write-host "Compressing dlls..." + write-host "Installing dlls to dist dir ${script:distDir}" $dlls = dir "${script:buildDir}/bin/*.dll" foreach ($file in $dlls) { - & "$script:GZIP" --best -f $file + copy-item -Path $file -Destination ${script:distDir} -Force } } @@ -169,123 +175,195 @@ function cleanup { } } -init_vars -git_module_setup -apply_patches # -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver -$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on") -if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) { +function build_static() { + if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) { + # GCC build for direct linking into the Go binary + init_vars + # cmake will silently fall back to msvc compilers if mingw isn't in the path, so detect and fail fast + # as we need this to be compiled by gcc for golang to be able to link with it + write-host "Checking for MinGW..." + # error action ensures we exit on failure + get-command gcc + get-command mingw32-make + $oldTargets = $script:cmakeTargets + $script:cmakeTargets = @("llama", "ggml") + $script:cmakeDefs = @( + "-G", "MinGW Makefiles" + "-DCMAKE_C_COMPILER=gcc.exe", + "-DCMAKE_CXX_COMPILER=g++.exe", + "-DBUILD_SHARED_LIBS=off", + "-DLLAMA_NATIVE=off", + "-DLLAMA_AVX=off", + "-DLLAMA_AVX2=off", + "-DLLAMA_AVX512=off", + "-DLLAMA_F16C=off", + "-DLLAMA_FMA=off") + $script:buildDir="../build/windows/${script:ARCH}_static" + write-host "Building static library" + build + $script:cmakeTargets = $oldTargets + } else { + write-host "Skipping static library generation step as requested" + } +} + +function build_cpu($gen_arch) { + if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) { + # remaining llama.cpp builds use MSVC + init_vars + $script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs + $script:buildDir="../build/windows/${script:ARCH}/cpu" + $script:distDir="$script:DIST_BASE\cpu" + write-host "Building LCD CPU" + build + sign + install + } else { + write-host "Skipping CPU generation step as requested" + } +} + +function build_cpu_avx() { + if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) { + init_vars + $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs + $script:buildDir="../build/windows/${script:ARCH}/cpu_avx" + $script:distDir="$script:DIST_BASE\cpu_avx" + write-host "Building AVX CPU" + build + sign + install + } else { + write-host "Skipping CPU AVX generation step as requested" + } +} + +function build_cpu_avx2() { + if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) { + init_vars + $script:cmakeDefs
= $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs + $script:buildDir="../build/windows/${script:ARCH}/cpu_avx2" + $script:distDir="$script:DIST_BASE\cpu_avx2" + write-host "Building AVX2 CPU" + build + sign + install + } else { + write-host "Skipping CPU AVX2 generation step as requested" + } +} + +function build_cuda() { + if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) { + # Then build cuda as a dynamically loaded library + $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe" + $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename + if ($null -ne $script:CUDA_VERSION) { + $script:CUDA_VARIANT="_"+$script:CUDA_VERSION + } + init_vars + $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT" + $script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT" + $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}") + if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) { + write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`"" + $script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}") + write-host "building custom CUDA GPU" + } + build + sign + install + + write-host "copying CUDA dependencies to ${script:SRC_DIR}\dist\windows-${script:ARCH}\" + cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\" + cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\" + cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\" + } else { + write-host "Skipping CUDA generation step" + } +} + +function build_rocm() { + if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) { + $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename + if ($null -ne $script:ROCM_VERSION) { + $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION + } + + init_vars + $script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT" + $script:distDir="$script:DIST_BASE\rocm$script:ROCM_VARIANT" + $script:cmakeDefs += @( + "-G", "Ninja", + "-DCMAKE_C_COMPILER=clang.exe", + "-DCMAKE_CXX_COMPILER=clang++.exe", + "-DLLAMA_HIPBLAS=on", + "-DHIP_PLATFORM=amd", + "-DLLAMA_AVX=on", + "-DLLAMA_AVX2=off", + "-DCMAKE_POSITION_INDEPENDENT_CODE=on", + "-DAMDGPU_TARGETS=$(amdGPUs)", + "-DGPU_TARGETS=$(amdGPUs)" + ) + + # Make sure the ROCm binary dir is first in the path + $env:PATH="$env:HIP_PATH\bin;$env:PATH" + + # We have to clobber the LIB var from the developer shell for clang to work properly + $env:LIB="" + if ($null -ne $env:OLLAMA_CUSTOM_ROCM_DEFS) { + write-host "OLLAMA_CUSTOM_ROCM_DEFS=`"${env:OLLAMA_CUSTOM_ROCM_DEFS}`"" + $script:cmakeDefs += @("${env:OLLAMA_CUSTOM_ROCM_DEFS}") + write-host "building custom ROCM GPU" + } + write-host "Building ROCm" + build + # Ninja doesn't prefix with config name + ${script:config}="" + if ($null -ne $script:DUMPBIN) { + & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll" + } + sign + install + + # Assumes v5.7, may need adjustments for v6 + rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\" + md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null + cp "${env:HIP_PATH}\bin\hipblas.dll" 
"${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\" + cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\" + # amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs + cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" + } else { + write-host "Skipping ROCm generation step" + } +} -# GCC build for direct linking into the Go binary init_vars -# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast -# as we need this to be compiled by gcc for golang to be able to link with itx -write-host "Checking for MinGW..." -# error action ensures we exit on failure -get-command gcc -get-command mingw32-make -$script:cmakeTargets = @("llama", "ggml") -$script:cmakeDefs = @( - "-G", "MinGW Makefiles" - "-DCMAKE_C_COMPILER=gcc.exe", - "-DCMAKE_CXX_COMPILER=g++.exe", - "-DBUILD_SHARED_LIBS=off", - "-DLLAMA_NATIVE=off", - "-DLLAMA_AVX=off", - "-DLLAMA_AVX2=off", - "-DLLAMA_AVX512=off", - "-DLLAMA_F16C=off", - "-DLLAMA_FMA=off") -$script:buildDir="../build/windows/${script:ARCH}_static" -write-host "Building static library" -build +if ($($args.count) -eq 0) { + git_module_setup + apply_patches + build_static + if ($script:ARCH -eq "arm64") { + build_cpu("ARM64") + } else { # amd64 + build_cpu("x64") + build_cpu_avx + build_cpu_avx2 + build_cuda + build_rocm + } -# remaining llama.cpp builds use MSVC - init_vars - $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs - $script:buildDir="../build/windows/${script:ARCH}/cpu" - write-host "Building LCD CPU" - build - sign - compress - - init_vars - $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs - $script:buildDir="../build/windows/${script:ARCH}/cpu_avx" - write-host "Building AVX CPU" - build - sign - compress - - init_vars - $script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs - $script:buildDir="../build/windows/${script:ARCH}/cpu_avx2" - write-host "Building AVX2 CPU" - build - sign - compress + cleanup + write-host "`ngo generate completed. 
LLM runners: $(get-childitem -path $script:DIST_BASE)" } else { - write-host "Skipping CPU generation step as requested" -} - -if ($null -ne $script:CUDA_LIB_DIR) { - # Then build cuda as a dynamically loaded library - $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe" - $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename - if ($null -ne $script:CUDA_VERSION) { - $script:CUDA_VARIANT="_"+$script:CUDA_VERSION - } - init_vars - $script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT" - $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}") - build - sign - compress -} - -if ($null -ne $env:HIP_PATH) { - $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename - if ($null -ne $script:ROCM_VERSION) { - $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION - } - - init_vars - $script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT" - $script:cmakeDefs += @( - "-G", "Ninja", - "-DCMAKE_C_COMPILER=clang.exe", - "-DCMAKE_CXX_COMPILER=clang++.exe", - "-DLLAMA_HIPBLAS=on", - "-DHIP_PLATFORM=amd", - "-DLLAMA_AVX=on", - "-DLLAMA_AVX2=off", - "-DCMAKE_POSITION_INDEPENDENT_CODE=on", - "-DAMDGPU_TARGETS=$(amdGPUs)", - "-DGPU_TARGETS=$(amdGPUs)" - ) - - # Make sure the ROCm binary dir is first in the path - $env:PATH="$env:HIP_PATH\bin;$env:PATH" - - # We have to clobber the LIB var from the developer shell for clang to work properly - $env:LIB="" - - write-host "Building ROCm" - build - # Ninja doesn't prefix with config name - ${script:config}="" - if ($null -ne $script:DUMPBIN) { - & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll" - } - sign - compress -} - - -cleanup -write-host "`ngo generate completed. 
LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})" + for ( $i = 0; $i -lt $args.count; $i++ ) { + write-host "performing $($args[$i])" + & $($args[$i]) + } +} \ No newline at end of file diff --git a/llm/ggml.go b/llm/ggml.go index f40f17e5..1c21bde0 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -13,82 +13,6 @@ type GGML struct { model } -const ( - fileTypeF32 uint32 = iota - fileTypeF16 - fileTypeQ4_0 - fileTypeQ4_1 - fileTypeQ4_1_F16 - fileTypeQ8_0 uint32 = iota + 2 - fileTypeQ5_0 - fileTypeQ5_1 - fileTypeQ2_K - fileTypeQ3_K_S - fileTypeQ3_K_M - fileTypeQ3_K_L - fileTypeQ4_K_S - fileTypeQ4_K_M - fileTypeQ5_K_S - fileTypeQ5_K_M - fileTypeQ6_K - fileTypeIQ2_XXS - fileTypeIQ2_XS - fileTypeQ2_K_S - fileTypeQ3_K_XS - fileTypeIQ3_XXS -) - -func fileType(fileType uint32) string { - switch fileType { - case fileTypeF32: - return "F32" - case fileTypeF16: - return "F16" - case fileTypeQ4_0: - return "Q4_0" - case fileTypeQ4_1: - return "Q4_1" - case fileTypeQ4_1_F16: - return "Q4_1_F16" - case fileTypeQ8_0: - return "Q8_0" - case fileTypeQ5_0: - return "Q5_0" - case fileTypeQ5_1: - return "Q5_1" - case fileTypeQ2_K: - return "Q2_K" - case fileTypeQ3_K_S: - return "Q3_K_S" - case fileTypeQ3_K_M: - return "Q3_K_M" - case fileTypeQ3_K_L: - return "Q3_K_L" - case fileTypeQ4_K_S: - return "Q4_K_S" - case fileTypeQ4_K_M: - return "Q4_K_M" - case fileTypeQ5_K_S: - return "Q5_K_S" - case fileTypeQ5_K_M: - return "Q5_K_M" - case fileTypeQ6_K: - return "Q6_K" - case fileTypeIQ2_XXS: - return "IQ2_XXS" - case fileTypeIQ2_XS: - return "IQ2_XS" - case fileTypeQ2_K_S: - return "Q2_K_S" - case fileTypeQ3_K_XS: - return "Q3_K_XS" - case fileTypeIQ3_XXS: - return "IQ3_XXS" - default: - return "unknown" - } -} - type model interface { KV() KV Tensors() Tensors @@ -121,12 +45,12 @@ func (kv KV) ParameterCount() uint64 { return kv.u64("general.parameter_count") } -func (kv KV) FileType() string { +func (kv KV) FileType() fileType { if u64 := kv.u64("general.file_type"); u64 > 0 { return fileType(uint32(u64)) } - return "unknown" + return fileTypeUnknown } func (kv KV) BlockCount() uint64 { @@ -286,6 +210,23 @@ const ( var ErrUnsupportedFormat = errors.New("unsupported model format") +func DetectGGMLType(b []byte) string { + switch binary.LittleEndian.Uint32(b[:4]) { + case FILE_MAGIC_GGML: + return "ggml" + case FILE_MAGIC_GGMF: + return "ggmf" + case FILE_MAGIC_GGJT: + return "ggjt" + case FILE_MAGIC_GGLA: + return "ggla" + case FILE_MAGIC_GGUF_LE, FILE_MAGIC_GGUF_BE: + return "gguf" + default: + return "" + } +} + func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) { var magic uint32 if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { @@ -343,7 +284,15 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui 4*batch*(embedding+vocab)+embedding*vocab*105/128, ) - if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok { + if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok { + // mixtral 8x22b + ff := uint64(llm.KV()["llama.feed_forward_length"].(uint32)) + partialOffload = max( + 3*ffnGateExpsWeight.size()+4*batch*(2*ff+headsKV+embedding+context+embedding/heads*headsKV), + 4*(context*batch*heads+context*embedding/heads*headsKV+batch*1024+embedding/heads*headsKV*batch), + ) + } else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok { + // mixtral 8x7b ffnGateWeight1 := ffnGateWeight.Shape[1] fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1) 
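DetectGGMLType above dispatches purely on the four-byte magic at offset zero of the file. A short caller sketch, with an illustrative file name and assuming the package's exported API exactly as shown in the hunk:

```go
package main

import (
	"fmt"
	"io"
	"log"
	"os"

	"github.com/ollama/ollama/llm"
)

func main() {
	f, err := os.Open("model.gguf") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// DetectGGMLType only inspects the first four bytes.
	magic := make([]byte, 4)
	if _, err := io.ReadFull(f, magic); err != nil {
		log.Fatal(err)
	}
	switch kind := llm.DetectGGMLType(magic); kind {
	case "":
		fmt.Println("not a GGML-family container")
	default:
		fmt.Println("container type:", kind) // e.g. "gguf"
	}
}
```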
partialOffload = max( diff --git a/llm/gguf.go b/llm/gguf.go index acdeb29f..5f6e8004 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -190,8 +190,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error { llm.kv[k] = v } - slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"])) - // decode tensors for i := 0; uint64(i) < llm.numTensor(); i++ { name, err := readGGUFString(llm, rs) @@ -465,11 +463,13 @@ var ggufKVOrder = map[string][]string{ "llama.embedding_length", "llama.block_count", "llama.feed_forward_length", - "llama.rope.dimension_count", "llama.attention.head_count", "llama.attention.head_count_kv", "llama.attention.layer_norm_rms_epsilon", "llama.rope.freq_base", + "llama.rope.dimension_count", + "llama.expert_count", + "llama.expert_used_count", "gemma.context_length", "gemma.embedding_length", "gemma.block_count", @@ -577,6 +577,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error { return err } } + default: + return fmt.Errorf("improper type for '%s'", k) } if err != nil { return err @@ -598,9 +600,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error { return err } - dims := 1 - if tensor.Shape[1] > 0 { - dims = 2 + dims := 0 + for cnt := 0; cnt < len(tensor.Shape); cnt++ { + if tensor.Shape[cnt] > 0 { + dims++ + } } if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil { diff --git a/llm/llama.cpp b/llm/llama.cpp index 7593639c..952d03db 160000 --- a/llm/llama.cpp +++ b/llm/llama.cpp @@ -1 +1 @@ -Subproject commit 7593639ce335e8d7f89aa9a54d616951f273af60 +Subproject commit 952d03dbead16e4dbdd1d3458486340673cc2465 diff --git a/llm/llm.go b/llm/llm.go index 33949c76..2a0c4b91 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -4,6 +4,7 @@ package llm // #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++ // #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++ // #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++ +// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++ // #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++ // #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++ // #include @@ -19,7 +20,7 @@ func SystemInfo() string { return C.GoString(C.llama_print_system_info()) } -func Quantize(infile, outfile, filetype string) error { +func Quantize(infile, outfile string, ftype fileType) error { cinfile := C.CString(infile) defer C.free(unsafe.Pointer(cinfile)) @@ -28,58 +29,10 @@ func Quantize(infile, outfile, filetype string) error { params := C.llama_model_quantize_default_params() params.nthread = -1 + params.ftype = ftype.Value() - switch filetype { - case "F32": - params.ftype = fileTypeF32 - case "F16": - params.ftype = fileTypeF16 - case "Q4_0": - params.ftype = fileTypeQ4_0 - case "Q4_1": - params.ftype = fileTypeQ4_1 - case "Q4_1_F16": - params.ftype = fileTypeQ4_1_F16 - case "Q8_0": - params.ftype = fileTypeQ8_0 - case "Q5_0": - params.ftype = fileTypeQ5_0 - case "Q5_1": - params.ftype = fileTypeQ5_1 - case "Q2_K": - params.ftype = fileTypeQ2_K - case "Q3_K_S": - params.ftype = fileTypeQ3_K_S - case "Q3_K_M": - params.ftype = fileTypeQ3_K_M - case "Q3_K_L": - params.ftype = fileTypeQ3_K_L - case "Q4_K_S": - params.ftype = fileTypeQ4_K_S - case "Q4_K_M": - params.ftype = fileTypeQ4_K_M - case "Q5_K_S": - params.ftype = fileTypeQ5_K_S - case "Q5_K_M": - params.ftype = fileTypeQ5_K_M - 
case "Q6_K": - params.ftype = fileTypeQ6_K - case "IQ2_XXS": - params.ftype = fileTypeIQ2_XXS - case "IQ2_XS": - params.ftype = fileTypeIQ2_XS - case "Q2_K_S": - params.ftype = fileTypeQ2_K_S - case "Q3_K_XS": - params.ftype = fileTypeQ3_K_XS - case "IQ3_XXS": - params.ftype = fileTypeIQ3_XXS - default: - return fmt.Errorf("unknown filetype: %s", filetype) - } - - if retval := C.llama_model_quantize(cinfile, coutfile, ¶ms); retval != 0 { - return fmt.Errorf("llama_model_quantize: %d", retval) + if rc := C.llama_model_quantize(cinfile, coutfile, ¶ms); rc != 0 { + return fmt.Errorf("llama_model_quantize: %d", rc) } return nil diff --git a/llm/llm_windows.go b/llm/llm_windows.go index 17967b4e..e44f4b95 100644 --- a/llm/llm_windows.go +++ b/llm/llm_windows.go @@ -2,5 +2,5 @@ package llm import "embed" -//go:embed build/windows/*/*/bin/* +// unused on windows var libEmbed embed.FS diff --git a/llm/memory.go b/llm/memory.go new file mode 100644 index 00000000..6890b08c --- /dev/null +++ b/llm/memory.go @@ -0,0 +1,185 @@ +package llm + +import ( + "fmt" + "log/slog" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/server/envconfig" +) + +// This algorithm looks for a complete fit to determine if we need to unload other models +func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) { + var estimatedVRAM uint64 + if opts.NumCtx > int(ggml.KV().ContextLength()) { + slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength()) + opts.NumCtx = int(ggml.KV().ContextLength()) + } + + if opts.NumCtx < 4 { + opts.NumCtx = 4 + } + + // Split up the GPUs by type and try them + for _, gpus := range allGpus.ByLibrary() { + var layerCount int + layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts) + if opts.NumGPU < 0 { + if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) { + return true, estimatedVRAM + } + } else { + if layerCount > 0 && layerCount >= opts.NumGPU { + return true, estimatedVRAM + } + } + } + return false, estimatedVRAM +} + +// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size +// The GPUs provided must all be the same Library +func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) { + var memoryAvailable uint64 + for _, info := range gpus { + memoryAvailable += info.FreeMemory + } + if envconfig.MaxVRAM > 0 { + memoryAvailable = envconfig.MaxVRAM + } + + slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable)) + + // TODO - this is probably wrong, first GPU vs secondaries will have different overheads + memoryMinimum := gpus[0].MinimumMemory + + for _, projector := range projectors { + memoryMinimum += projectorMemoryRequirements(projector) + + // multimodal models require at least 2048 context + opts.NumCtx = max(opts.NumCtx, 2048) + } + + // fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv + var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV() + + graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch))) + if graphPartialOffload == 0 { + 
graphPartialOffload = ggml.KV().GQA() * kv / 6 + } + + if graphFullOffload == 0 { + graphFullOffload = graphPartialOffload + } + + graphFullOffload *= uint64(len(gpus)) + graphPartialOffload *= uint64(len(gpus)) + + // on metal there's no partial offload overhead + if gpus[0].Library == "metal" { + graphPartialOffload = graphFullOffload + } + + layers := ggml.Tensors().Layers() + + // memoryRequiredTotal represents the memory required for full GPU offloading (all layers) + memoryRequiredTotal := memoryMinimum + graphFullOffload + layers["blk.0"].size() + + // memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers) + memoryRequiredPartial := memoryMinimum + graphPartialOffload + layers["blk.0"].size() + + var memoryLayerOutput uint64 + if layer, ok := layers["output_norm"]; ok { + memoryLayerOutput += layer.size() + } + + if layer, ok := layers["output"]; ok { + memoryLayerOutput += layer.size() + } else if layer, ok := layers["token_embd"]; ok { + memoryLayerOutput += layer.size() + } + + if gpus[0].Library == "metal" && opts.UseMMap { + // memory is preallocated for output tensors + memoryRequiredTotal += memoryLayerOutput + memoryRequiredPartial += memoryLayerOutput + } + + var layerCount int + for i := 0; i < int(ggml.KV().BlockCount()); i++ { + memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size() + + // KV is proportional to the number of layers + memoryLayer += kv / ggml.KV().BlockCount() + + memoryRequiredTotal += memoryLayer + if memoryAvailable > memoryRequiredPartial+memoryLayer { + memoryRequiredPartial += memoryLayer + layerCount++ + } + } + + if gpus[0].Library != "metal" || !opts.UseMMap { + // memory was not preallocated for output tensors + memoryRequiredTotal += memoryLayerOutput + } + + if memoryAvailable > memoryRequiredTotal { + layerCount = int(ggml.KV().BlockCount()) + 1 + memoryRequiredPartial = memoryRequiredTotal + } + + memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv + + slog.Info( + "offload to gpu", + slog.Group( + "layers", + // actual number of layers offloaded + "real", opts.NumGPU, + // estimated number of layers that can be offloaded + "estimate", layerCount, + ), + slog.Group( + "memory", + // memory available for offloading + "available", format.HumanBytes2(memoryAvailable), + slog.Group( + "required", + // memory required for full offloading + "full", format.HumanBytes2(memoryRequiredTotal), + // memory required to offload layers.estimate layers + "partial", format.HumanBytes2(memoryRequiredPartial), + // memory of KV cache + "kv", format.HumanBytes2(kv), + ), + slog.Group( + "weights", + // memory of the weights + "total", format.HumanBytes2(memoryWeights), + // memory of repeating layers + "repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput), + // memory of non-repeating layers + "nonrepeating", format.HumanBytes2(memoryLayerOutput), + ), + slog.Group( + "graph", + // memory of graph when fully offloaded + "full", format.HumanBytes2(graphFullOffload), + // memory of graph when not fully offloaded + "partial", format.HumanBytes2(graphPartialOffload), + ), + ), + ) + if gpus[0].Library == "cpu" { + return 0, 0, memoryRequiredTotal + } + if memoryRequiredPartial > memoryAvailable { + slog.Debug("insufficient VRAM to load any model layers") + return 0, 0, memoryRequiredTotal + } + + return layerCount, memoryRequiredPartial, memoryRequiredTotal +} diff --git a/llm/patches/02-clip-log.diff b/llm/patches/02-clip-log.diff new file mode 100644 index 00000000..34a018e8 --- 
/dev/null +++ b/llm/patches/02-clip-log.diff @@ -0,0 +1,12 @@ +diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp +index e431c7f7..f077e688 100644 +--- a/examples/llava/clip.cpp ++++ b/examples/llava/clip.cpp +@@ -3,6 +3,7 @@ + // I'll gradually clean and extend it + // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch + #include "clip.h" ++#include "common.h" + #include "log.h" + #include "ggml.h" + #include "ggml-alloc.h" diff --git a/llm/patches/04-metal.diff b/llm/patches/04-metal.diff new file mode 100644 index 00000000..f8fa7db7 --- /dev/null +++ b/llm/patches/04-metal.diff @@ -0,0 +1,45 @@ +diff --git a/ggml-metal.m b/ggml-metal.m +index 0207b787..b5e9884b 100644 +--- a/ggml-metal.m ++++ b/ggml-metal.m +@@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute( + // to the matrix-vector kernel + int ne11_mm_min = 1; + +-#if 0 + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach +- if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { +- switch (src0t) { +- case GGML_TYPE_F16: ne11_mm_min = 2; break; +- case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; +- case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; +- case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; +- case GGML_TYPE_Q4_0: +- case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; +- case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; +- case GGML_TYPE_Q5_0: // not tested yet +- case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet +- case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; +- case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; +- default: ne11_mm_min = 1; break; +- } ++ switch (src0t) { ++ case GGML_TYPE_F16: ne11_mm_min = 2; break; ++ case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; ++ case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; ++ case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; ++ case GGML_TYPE_Q4_0: ++ case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; ++ case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; ++ case GGML_TYPE_Q5_0: // not tested yet ++ case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet ++ case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; ++ case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; ++ default: ne11_mm_min = 1; break; + } +-#endif + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel diff --git a/llm/patches/05-clip-fix.diff b/llm/patches/05-clip-fix.diff new file mode 100644 index 00000000..3f68a5bb --- /dev/null +++ b/llm/patches/05-clip-fix.diff @@ -0,0 +1,24 @@ +diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp +index e3c9bcd4..b43f892d 100644 +--- a/examples/llava/clip.cpp ++++ b/examples/llava/clip.cpp +@@ -573,14 +573,16 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 + struct ggml_tensor * embeddings = inp; + if (ctx->has_class_embedding) { + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); ++ } ++ ggml_set_name(embeddings, "embeddings"); ++ ggml_set_input(embeddings); ++ ++ if (ctx->has_class_embedding) { + embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); + embeddings = ggml_acc(ctx0, embeddings, inp, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 
model.class_embedding->nb[1]); + } +- ggml_set_name(embeddings, "embeddings"); +- ggml_set_input(embeddings); +- + + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + ggml_set_name(positions, "positions"); diff --git a/llm/payload.go b/llm/payload.go index 8a134357..abe3d263 100644 --- a/llm/payload.go +++ b/llm/payload.go @@ -9,6 +9,7 @@ import ( "log/slog" "os" "path/filepath" + "runtime" "strings" "golang.org/x/exp/slices" @@ -17,7 +18,7 @@ import ( "github.com/ollama/ollama/gpu" ) -var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama") +var errPayloadMissing = errors.New("expected payloads not included in this build of ollama") func Init() error { payloadsDir, err := gpu.PayloadsDir() @@ -25,13 +26,15 @@ func Init() error { return err } - slog.Info("extracting embedded files", "dir", payloadsDir) - binGlob := "build/*/*/*/bin/*" + if runtime.GOOS != "windows" { + slog.Info("extracting embedded files", "dir", payloadsDir) + binGlob := "build/*/*/*/bin/*" - // extract server libraries - err = extractFiles(payloadsDir, binGlob) - if err != nil { - return fmt.Errorf("extract binaries: %v", err) + // extract server libraries + err = extractFiles(payloadsDir, binGlob) + if err != nil { + return fmt.Errorf("extract binaries: %v", err) + } } var variants []string @@ -138,6 +141,23 @@ func serversForGpu(info gpu.GpuInfo) []string { return servers } +// Return the optimal server for this CPU architecture +func serverForCpu() string { + if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { + return "metal" + } + variant := gpu.GetCPUVariant() + availableServers := availableServers() + if variant != "" { + for cmp := range availableServers { + if cmp == "cpu_"+variant { + return cmp + } + } + } + return "cpu" +} + // extract extracts the embedded files to the target directory func extractFiles(targetDir string, glob string) error { files, err := fs.Glob(libEmbed, glob) diff --git a/llm/server.go b/llm/server.go index 02780c2e..b452434e 100644 --- a/llm/server.go +++ b/llm/server.go @@ -21,21 +21,47 @@ import ( "strings" "time" + "golang.org/x/sync/semaphore" + "github.com/ollama/ollama/api" "github.com/ollama/ollama/format" "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/server/envconfig" ) -// LlamaServer is an instance of the llama.cpp server -type LlamaServer struct { +type LlamaServer interface { + Ping(ctx context.Context) error + WaitUntilRunning(ctx context.Context) error + Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error + Embedding(ctx context.Context, prompt string) ([]float64, error) + Tokenize(ctx context.Context, content string) ([]int, error) + Detokenize(ctx context.Context, tokens []int) (string, error) + Close() error + EstimatedVRAM() uint64 +} + +// llmServer is an instance of the llama.cpp server +type llmServer struct { port int cmd *exec.Cmd done chan error // Channel to signal when the process exits status *StatusWriter options api.Options + + // TODO - this should be broken down by GPU + estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model + estimatedTotal uint64 // Total size of model + totalLayers uint64 + gpuCount int + + sem *semaphore.Weighted } -func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) { +func LoadModel(model string) (*GGML, error) { + if _, err := os.Stat(model); err != nil { + return nil, err + } + f, err := os.Open(model) if err != nil { return nil, err 
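For a sense of scale, the fp16 KV-cache estimate in memory.go above is straightforward arithmetic: two bytes each for K and V, per context position, per layer, with the head_count_kv/head_count ratio accounting for grouped-query attention. A worked sketch with Llama-2-7B-like shapes (illustrative numbers, not read from a real model):

```go
package main

import "fmt"

func main() {
	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
	var (
		nCtx    uint64 = 2048 // context length (n_ctx)
		nLayer  uint64 = 32   // block count (n_layer)
		nEmbd   uint64 = 4096 // embedding length (n_embd)
		nHead   uint64 = 32   // attention heads (n_head)
		nHeadKV uint64 = 32   // KV heads (n_head_kv); smaller under GQA
	)
	kv := 2 * 2 * nCtx * nLayer * nEmbd / nHead * nHeadKV
	fmt.Printf("%d bytes (%.2f GiB)\n", kv, float64(kv)/(1<<30))
	// With these shapes the cache is exactly 1 GiB; grouped-query models
	// shrink it by the factor nHead/nHeadKV.
}
```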
@@ -43,144 +69,69 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option defer f.Close() ggml, _, err := DecodeGGML(f) - if err != nil { - return nil, err - } + return ggml, err +} +// NewLlamaServer will run a server for the given GPUs +// The gpu list must be a single family. +func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) { + var err error if opts.NumCtx > int(ggml.KV().ContextLength()) { - slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength()) - opts.NumCtx = int(ggml.KV().ContextLength()) + slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength()) } if opts.NumCtx < 4 { opts.NumCtx = 4 } - memoryAvailable, _ := gpu.CheckVRAM() - info := gpu.GetGPUInfo() + cpuRunner := "" + var estimatedVRAM uint64 + var estimatedTotal uint64 + var systemMemory uint64 + gpuCount := len(gpus) + if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 { - memoryMinimum := info.MinimumMemory - for _, projector := range projectors { - memoryMinimum += projectorMemoryRequirements(projector) + // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner - // multimodal models require at least 2048 context - opts.NumCtx = max(opts.NumCtx, 2048) - } + cpuRunner = serverForCpu() + gpuCount = 0 + } else { + if gpus[0].Library == "metal" { + memInfo, err := gpu.GetCPUMem() + if err != nil { + slog.Error("failed to lookup system memory", "error", err) + } else { + systemMemory = memInfo.TotalMemory + slog.Debug("system memory", "total", format.HumanBytes2(systemMemory)) + } + } + var layers int + layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts) - // fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv - var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV() - - graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch))) - if graphPartialOffload == 0 { - graphPartialOffload = ggml.KV().GQA() * kv / 6 - } - - if graphFullOffload == 0 { - graphFullOffload = graphPartialOffload - } - - graphFullOffload *= uint64(info.DeviceCount) - graphPartialOffload *= uint64(info.DeviceCount) - - // memoryRequiredTotal represents the memory required for full GPU offloading (all layers) - memoryRequiredTotal := memoryMinimum + graphFullOffload - - // memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers) - memoryRequiredPartial := memoryMinimum + graphPartialOffload - - if info.Library != "metal" { - if memoryRequiredPartial > memoryAvailable { - info.Library = "cpu" + if gpus[0].Library == "metal" && estimatedVRAM > systemMemory { + // disable partial offloading when model is greater than total system memory as this + // can lead to locking up the system + opts.NumGPU = 0 + } else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" { + opts.NumGPU = layers } } - var layerCount int - layers := ggml.Tensors().Layers() - for i := 0; i < int(ggml.KV().BlockCount()); i++ { - memoryLayer := layers[fmt.Sprintf("blk.%d", i)].size() - - // KV is proportional to the number of layers - memoryLayer 
+= kv / ggml.KV().BlockCount() - - memoryRequiredTotal += memoryLayer - if memoryAvailable > memoryRequiredPartial+memoryLayer { - memoryRequiredPartial += memoryLayer - layerCount++ - } - } - - var memoryLayerOutput uint64 - for k, v := range layers { - if !strings.HasPrefix(k, "blk.") { - memoryLayerOutput += v.size() - } - } - - memoryRequiredTotal += memoryLayerOutput - - if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory { - // disable partial offloading when model is greater than total system memory - opts.NumGPU = 0 - } else if memoryAvailable > memoryRequiredTotal { - layerCount = int(ggml.KV().BlockCount()) + 1 - memoryRequiredPartial = memoryRequiredTotal - } - - if opts.NumGPU < 0 { - opts.NumGPU = layerCount - } - - memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv - - slog.Info( - "offload to gpu", - slog.Group( - "layers", - // actual number of layers offloaded - "real", opts.NumGPU, - // estimated number of layers that can be offloaded - "estimate", layerCount, - ), - slog.Group( - "memory", - // memory available for offloading - "available", format.HumanBytes2(memoryAvailable), - slog.Group( - "required", - // memory required for full offloading - "full", format.HumanBytes2(memoryRequiredTotal), - // memory required to offload layers.estimate layers - "partial", format.HumanBytes2(memoryRequiredPartial), - // memory of KV cache - "kv", format.HumanBytes2(kv), - ), - slog.Group( - "weights", - // memory of the weights - "total", format.HumanBytes2(memoryWeights), - // memory of repeating layers - "repeating", format.HumanBytes2(memoryWeights-memoryLayerOutput), - // memory of non-repeating layers - "nonrepeating", format.HumanBytes2(memoryLayerOutput), - ), - slog.Group( - "graph", - // memory of graph when fully offloaded - "full", format.HumanBytes2(graphFullOffload), - // memory of graph when not fully offloaded - "partial", format.HumanBytes2(graphPartialOffload), - ), - ), - ) + // Loop through potential servers + finalErr := fmt.Errorf("no suitable llama servers found") if len(adapters) > 1 { return nil, errors.New("ollama supports only one lora adapter, but multiple were provided") } availableServers := availableServers() - servers := serversForGpu(info) - - demandLib := os.Getenv("OLLAMA_LLM_LIBRARY") + var servers []string + if cpuRunner != "" { + servers = []string{cpuRunner} + } else { + servers = serversForGpu(gpus[0]) // All GPUs in the list are matching Library and Variant + } + demandLib := envconfig.LLMLibrary if demandLib != "" { serverPath := availableServers[demandLib] if serverPath == "" { @@ -188,11 +139,15 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option } else { slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath) servers = []string{demandLib} + if strings.HasPrefix(demandLib, "cpu") { + // Omit the GPU flag to silence the warning + opts.NumGPU = -1 + } } } if len(servers) == 0 { - return nil, fmt.Errorf("no servers found for %v", info) + return nil, fmt.Errorf("no servers found for %v", gpus) } params := []string{ @@ -201,7 +156,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option "--batch-size", fmt.Sprintf("%d", opts.NumBatch), "--embedding", } - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { params = append(params, "--log-format", "json") } else { params = append(params, "--log-disable") @@ -211,7 +166,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts 
api.Option params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU)) } - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { params = append(params, "--verbose") } @@ -249,10 +204,30 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option params = append(params, "--numa") } - // Loop through potential servers - var finalErr error + numParallel := envconfig.NumParallel + + // TODO (jmorganca): multimodal models don't support parallel yet + // see https://github.com/ollama/ollama/issues/4165 + if len(projectors) > 0 { + numParallel = 1 + slog.Warn("multimodal models don't support parallel requests yet") + } + + params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) + for i := 0; i < len(servers); i++ { dir := availableServers[servers[i]] + if dir == "" { + // Shouldn't happen + finalErr = fmt.Errorf("[%d] server %s not listed in available servers %v", i, servers[i], availableServers) + slog.Error("server list inconsistent", "error", finalErr) + continue + } + + if strings.HasPrefix(servers[i], "cpu") { + // TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up + gpuCount = 0 + } // Find an available port, retry on each iteration in case the failure was a port conflict race port := 0 @@ -273,12 +248,21 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option if runtime.GOOS == "windows" { pathEnv = "PATH" } - // append the server directory to LD_LIBRARY_PATH/PATH + // prepend the server directory to LD_LIBRARY_PATH/PATH libraryPaths := []string{dir} + if libraryPath, ok := os.LookupEnv(pathEnv); ok { // Append our runner directory to the path // This will favor system libraries over our bundled library dependencies - libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...) + libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...) + } + + // Note: we always put the dependency path first + // since this was the exact version we verified for AMD GPUs + // and we favor what the user had in their path + if gpus[0].DependencyPath != "" { + // TODO refine for multi-gpu support + libraryPaths = append([]string{gpus[0].DependencyPath}, libraryPaths...) }
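The ordering above is deliberate: the runner directory is now prepended to the inherited search path, and the verified GPU dependency directory is pushed ahead of both. A minimal sketch of that precedence rule, with placeholder directories:

```go
package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// buildLibraryPath puts the runner dir ahead of the inherited search path,
// and an optional dependency dir ahead of everything else.
func buildLibraryPath(runnerDir, inherited, dependencyDir string) string {
	paths := []string{runnerDir}
	if inherited != "" {
		paths = append(paths, filepath.SplitList(inherited)...)
	}
	if dependencyDir != "" {
		paths = append([]string{dependencyDir}, paths...)
	}
	return strings.Join(paths, string(filepath.ListSeparator))
}

func main() {
	// placeholder values for illustration; on Linux this prints
	// /opt/rocm/lib:/tmp/runners/cuda:/usr/lib:/usr/local/lib
	fmt.Println(buildLibraryPath("/tmp/runners/cuda", "/usr/lib:/usr/local/lib", "/opt/rocm/lib"))
}
```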
server := filepath.Join(dir, "ollama_llama_server") @@ -286,21 +270,66 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option server = server + ".exe" } - s := &LlamaServer{ - port: port, - cmd: exec.Command(server, finalParams...), - status: NewStatusWriter(os.Stderr), - options: opts, + // Detect tmp cleaners wiping out the file + _, err := os.Stat(server) + if errors.Is(err, os.ErrNotExist) { + slog.Warn("llama server disappeared, reinitializing payloads", "path", server, "error", err) + err = Init() + if err != nil { + slog.Warn("failed to reinitialize payloads", "error", err) + return nil, err + } } - libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator))) - slog.Debug(libEnv) - s.cmd.Env = append(os.Environ(), libEnv) + + s := &llmServer{ + port: port, + cmd: exec.Command(server, finalParams...), + status: NewStatusWriter(os.Stderr), + options: opts, + estimatedVRAM: estimatedVRAM, + estimatedTotal: estimatedTotal, + sem: semaphore.NewWeighted(int64(numParallel)), + totalLayers: ggml.KV().BlockCount() + 1, + gpuCount: gpuCount, + } + + s.cmd.Env = os.Environ() s.cmd.Stdout = os.Stdout s.cmd.Stderr = s.status + visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv() + pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) + + // Update or add the path and visible devices variable with our adjusted version + pathNeeded := true + devicesNeeded := visibleDevicesEnv != "" + for i := range s.cmd.Env { + cmp := strings.SplitN(s.cmd.Env[i], "=", 2) + if strings.EqualFold(cmp[0], pathEnv) { + s.cmd.Env[i] = pathEnv + "=" + pathEnvVal + pathNeeded = false + } else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) { + s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal + devicesNeeded = false + } + } + if pathNeeded { + s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) + } + if devicesNeeded { + s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal) + } + slog.Info("starting llama server", "cmd", s.cmd.String()) + // Log at debug as the environment is inherited and might contain sensitive information + slog.Debug("subprocess", "environment", s.cmd.Env) if err = s.cmd.Start(); err != nil { + // Detect permission denied and augment the message about noexec + if errors.Is(err, os.ErrPermission) { + finalErr = fmt.Errorf("unable to start server %w. %s may have noexec set.
Set OLLAMA_TMPDIR for server to a writable executable directory", err, dir) + continue + } msg := "" if s.status != nil && s.status.LastErrMsg != "" { msg = s.status.LastErrMsg @@ -310,12 +339,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option continue } - // reap subprocess when it exits - go func() { - // Exit status managed via getServerStatus - _ = s.cmd.Wait() - }() - return s, nil } @@ -347,12 +370,27 @@ type ServerStatus int const ( // iota is reset to 0 ServerStatusReady ServerStatus = iota - ServerStatusNoSlotsAvaialble + ServerStatusNoSlotsAvailable ServerStatusLoadingModel ServerStatusNotResponding ServerStatusError ) +func (s ServerStatus) ToString() string { + switch s { + case ServerStatusReady: + return "llm server ready" + case ServerStatusNoSlotsAvailable: + return "llm busy - no slots available" + case ServerStatusLoadingModel: + return "llm server loading model" + case ServerStatusNotResponding: + return "llm server not responding" + default: + return "llm server error" + } +} + type ServerStatusResp struct { Status string `json:"status"` SlotsIdle int `json:"slots_idle"` @@ -360,13 +398,17 @@ type ServerStatusResp struct { Error string `json:"error"` } -func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) { +func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { // Fail fast if it's exited if s.cmd.ProcessState != nil { msg := "" if s.status != nil && s.status.LastErrMsg != "" { msg = s.status.LastErrMsg } + if s.cmd.ProcessState.ExitCode() == -1 { + // Most likely a signal killed it, log some more details to try to help troubleshoot + slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String()) + } return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) } @@ -399,7 +441,7 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) case "ok": return ServerStatusReady, nil case "no slot available": - return ServerStatusNoSlotsAvaialble, nil + return ServerStatusNoSlotsAvailable, nil case "loading model": return ServerStatusLoadingModel, nil default: @@ -407,7 +449,30 @@ func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) } } -func (s *LlamaServer) Ping(ctx context.Context) error { +// getServerStatusRetry will retry if ServerStatusNoSlotsAvailable is received +func (s *llmServer) getServerStatusRetry(ctx context.Context) (ServerStatus, error) { + var retries int + for { + status, err := s.getServerStatus(ctx) + if err != nil { + return status, err + } + + if status == ServerStatusNoSlotsAvailable { + if retries >= 10 { + return status, fmt.Errorf("no slots available after %d retries", retries) + } + + time.Sleep(5 * time.Millisecond) + retries++ + continue + } + + return status, nil + } +} + +func (s *llmServer) Ping(ctx context.Context) error { _, err := s.getServerStatus(ctx) if err != nil { slog.Debug("server unhealthy", "error", err) @@ -416,13 +481,25 @@ func (s *LlamaServer) Ping(ctx context.Context) error { return nil } -func (s *LlamaServer) WaitUntilRunning() error { +func (s *llmServer) WaitUntilRunning(ctx context.Context) error { start := time.Now() expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load slog.Info("waiting for llama runner to start responding") for { + select { + case <-ctx.Done(): + slog.Info("context expired before
server started") + return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err()) + case err := <-s.done: + msg := "" + if s.status != nil && s.status.LastErrMsg != "" { + msg = s.status.LastErrMsg + } + return fmt.Errorf("llama runner process has terminated: %v %s", err, msg) + default: + } ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() status, err := s.getServerStatus(ctx) @@ -487,7 +564,6 @@ ws ::= ([ \t\n] ws)? ` const maxBufferSize = 512 * format.KiloByte -const maxRetries = 3 type ImageData struct { Data []byte `json:"data"` @@ -524,7 +600,19 @@ type CompletionResponse struct { EvalDuration time.Duration } -func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { +func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { + if err := s.sem.Acquire(ctx, 1); err != nil { + slog.Error("Failed to acquire semaphore", "error", err) + return err + } + defer s.sem.Release(1) + + // only allow maximum 10 "context shifts" to avoid infinite generation + if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx { + req.Options.NumPredict = 10 * s.options.NumCtx + slog.Debug("setting token limit to 10x num_ctx", "num_ctx", s.options.NumCtx, "num_predict", req.Options.NumPredict) + } + request := map[string]any{ "prompt": req.Prompt, "stream": true, @@ -551,11 +639,11 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn } // Make sure the server is ready - status, err := s.getServerStatus(ctx) + status, err := s.getServerStatusRetry(ctx) if err != nil { return err } else if status != ServerStatusReady { - return fmt.Errorf("unexpected server status: %d", status) + return fmt.Errorf("unexpected server status: %s", status.ToString()) } if req.Format == "json" { @@ -565,133 +653,113 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn } } - retryDelay := 100 * time.Microsecond - for retries := 0; retries < maxRetries; retries++ { - if retries > 0 { - time.Sleep(retryDelay) // wait before retrying - retryDelay *= 2 // exponential backoff - } + // Handling JSON marshaling with special characters unescaped. + buffer := &bytes.Buffer{} + enc := json.NewEncoder(buffer) + enc.SetEscapeHTML(false) - // Handling JSON marshaling with special characters unescaped. 
@@ -565,133 +653,113 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		}
 	}
 
-	retryDelay := 100 * time.Microsecond
-	for retries := 0; retries < maxRetries; retries++ {
-		if retries > 0 {
-			time.Sleep(retryDelay) // wait before retrying
-			retryDelay *= 2        // exponential backoff
-		}
+	// Handling JSON marshaling with special characters unescaped.
+	buffer := &bytes.Buffer{}
+	enc := json.NewEncoder(buffer)
+	enc.SetEscapeHTML(false)
 
-		// Handling JSON marshaling with special characters unescaped.
-		buffer := &bytes.Buffer{}
-		enc := json.NewEncoder(buffer)
-		enc.SetEscapeHTML(false)
+	if err := enc.Encode(request); err != nil {
+		return fmt.Errorf("failed to marshal data: %v", err)
+	}
 
-		if err := enc.Encode(request); err != nil {
-			return fmt.Errorf("failed to marshal data: %v", err)
-		}
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
+	serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	if err != nil {
+		return fmt.Errorf("error creating POST request: %v", err)
+	}
+	serverReq.Header.Set("Content-Type", "application/json")
 
-		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
-		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
+	res, err := http.DefaultClient.Do(serverReq)
+	if err != nil {
+		return fmt.Errorf("POST predict: %v", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode >= 400 {
+		bodyBytes, err := io.ReadAll(res.Body)
 		if err != nil {
-			return fmt.Errorf("error creating POST request: %v", err)
+			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
-		req.Header.Set("Content-Type", "application/json")
+		log.Printf("llm predict error: %s", bodyBytes)
+		return fmt.Errorf("%s", bodyBytes)
+	}
 
-		resp, err := http.DefaultClient.Do(req)
-		if err != nil {
-			return fmt.Errorf("POST predict: %v", err)
-		}
-		defer resp.Body.Close()
+	scanner := bufio.NewScanner(res.Body)
+	buf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(buf, maxBufferSize)
 
-		if resp.StatusCode >= 400 {
-			bodyBytes, err := io.ReadAll(resp.Body)
-			if err != nil {
-				return fmt.Errorf("failed reading llm error response: %w", err)
-			}
-			log.Printf("llm predict error: %s", bodyBytes)
-			return fmt.Errorf("%s", bodyBytes)
-		}
+	// keep track of the last token generated, this is used to abort if the model starts looping
+	var lastToken string
+	var tokenRepeat int
 
-		scanner := bufio.NewScanner(resp.Body)
-		buf := make([]byte, 0, maxBufferSize)
-		scanner.Buffer(buf, maxBufferSize)
-
-		retryNeeded := false
-		// keep track of the last token generated, this is used to abort if the model starts looping
-		var lastToken string
-		var tokenRepeat int
-
-		for scanner.Scan() {
-			select {
-			case <-ctx.Done():
-				// This handles the request cancellation
-				return ctx.Err()
-			default:
-				line := scanner.Bytes()
-				if len(line) == 0 {
-					continue
-				}
-
-				// try again on slot unavailable
-				if bytes.Contains(line, []byte("slot unavailable")) {
-					retryNeeded = true
-					break
-				}
-
-				evt, ok := bytes.CutPrefix(line, []byte("data: "))
-				if !ok {
-					return fmt.Errorf("error parsing llm response stream: %s", line)
-				}
-
-				var c completion
-				if err := json.Unmarshal(evt, &c); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
-
-				switch {
-				case strings.TrimSpace(c.Content) == lastToken:
-					tokenRepeat++
-				default:
-					lastToken = strings.TrimSpace(c.Content)
-					tokenRepeat = 0
-				}
-
-				// 30 picked as an arbitrary max token repeat limit, modify as needed
-				if tokenRepeat > 30 {
-					slog.Debug("prediction aborted, token repeat limit reached")
-					return ctx.Err()
-				}
-
-				if c.Content != "" {
-					fn(CompletionResponse{
-						Content: c.Content,
-					})
-				}
-
-				if c.Stop {
-					fn(CompletionResponse{
-						Done:               true,
-						PromptEvalCount:    c.Timings.PromptN,
-						PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
-						EvalCount:          c.Timings.PredictedN,
-						EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
-					})
-					return nil
-				}
+	for scanner.Scan() {
+		select {
+		case <-ctx.Done():
+			// This handles the request cancellation
+			return ctx.Err()
+		default:
+			line := scanner.Bytes()
+			if len(line) == 0 {
+				continue
 			}
-		}
 
-		if err := scanner.Err(); err != nil {
-			if strings.Contains(err.Error(), "unexpected EOF") {
-				s.Close()
-				msg := ""
-				if s.status != nil && s.status.LastErrMsg != "" {
-					msg = s.status.LastErrMsg
-				}
-
-				return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+			evt, ok := bytes.CutPrefix(line, []byte("data: "))
+			if !ok {
+				return fmt.Errorf("error parsing llm response stream: %s", line)
 			}
-			return fmt.Errorf("error reading llm response: %v", err)
-		}
 
-		if !retryNeeded {
-			return nil // success
+			var c completion
+			if err := json.Unmarshal(evt, &c); err != nil {
+				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+			}
+
+			switch {
+			case strings.TrimSpace(c.Content) == lastToken:
+				tokenRepeat++
+			default:
+				lastToken = strings.TrimSpace(c.Content)
+				tokenRepeat = 0
+			}
+
+			// 30 picked as an arbitrary max token repeat limit, modify as needed
+			if tokenRepeat > 30 {
+				slog.Debug("prediction aborted, token repeat limit reached")
+				return ctx.Err()
+			}
+
+			if c.Content != "" {
+				fn(CompletionResponse{
+					Content: c.Content,
+				})
+			}
+
+			if c.Stop {
+				fn(CompletionResponse{
+					Done:               true,
+					PromptEvalCount:    c.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
+					EvalCount:          c.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				})
+				return nil
+			}
 		}
 	}
 
-	// should never reach here ideally
-	return fmt.Errorf("max retries exceeded")
+	if err := scanner.Err(); err != nil {
+		if strings.Contains(err.Error(), "unexpected EOF") {
+			s.Close()
+			msg := ""
+			if s.status != nil && s.status.LastErrMsg != "" {
+				msg = s.status.LastErrMsg
+			}
+			return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
+		}
+
+		return fmt.Errorf("error reading llm response: %v", err)
+	}
+
+	return nil
 }
 
 type EmbeddingRequest struct {
@@ -702,13 +770,19 @@ type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`
 }
 
-func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+	if err := s.sem.Acquire(ctx, 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return nil, err
+	}
+	defer s.sem.Release(1)
+
 	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
+	status, err := s.getServerStatusRetry(ctx)
 	if err != nil {
 		return nil, err
 	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(TokenizeRequest{Content: prompt})
@@ -754,13 +828,13 @@ type TokenizeResponse struct {
 	Tokens []int `json:"tokens"`
 }
 
-func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return nil, err
-	} else if status != ServerStatusReady {
-		return nil, fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(TokenizeRequest{Content: content})
@@ -806,13 +880,13 @@ type DetokenizeResponse struct {
 	Content string `json:"content"`
 }
 
-func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	// Make sure the server is ready
 	status, err := s.getServerStatus(ctx)
 	if err != nil {
 		return "", err
-	} else if status != ServerStatusReady {
-		return "", fmt.Errorf("unexpected server status: %d", status)
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
 	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
@@ -850,15 +924,25 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
 	return decoded.Content, nil
 }
 
-func (s *LlamaServer) Close() error {
+func (s *llmServer) Close() error {
 	if s.cmd != nil {
 		slog.Debug("stopping llama server")
-		return s.cmd.Process.Kill()
+		if err := s.cmd.Process.Kill(); err != nil {
+			return err
+		}
+
+		_ = s.cmd.Wait()
+
+		slog.Debug("llama server stopped")
 	}
 
 	return nil
 }
 
+func (s *llmServer) EstimatedVRAM() uint64 {
+	return s.estimatedVRAM
+}
+
 func parseDurationMs(ms float64) time.Duration {
 	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
 	if err != nil {
diff --git a/macapp/src/app.tsx b/macapp/src/app.tsx
index fc1df21c..ab17df60 100644
--- a/macapp/src/app.tsx
+++ b/macapp/src/app.tsx
@@ -19,7 +19,7 @@ export default function () {
   const [step, setStep] = useState(Step.WELCOME)
   const [commandCopied, setCommandCopied] = useState(false)
 
-  const command = 'ollama run llama2'
+  const command = 'ollama run llama3'
 
   return (
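The rewritten Completion loop above consumes the runner's event stream line by line: each event arrives as a "data: "-prefixed JSON object, and blank keep-alive lines are skipped. A self-contained sketch of that parsing step, with a stand-in completion struct for the unexported type in llm:

    // Sketch: parse a "data: "-prefixed JSON event stream the way the
    // rewritten Completion loop does (bufio.Scanner + bytes.CutPrefix).
    package main

    import (
        "bufio"
        "bytes"
        "encoding/json"
        "fmt"
        "strings"
    )

    type completion struct {
        Content string `json:"content"`
        Stop    bool   `json:"stop"`
    }

    func main() {
        body := "data: {\"content\":\"Hello\"}\n\ndata: {\"content\":\"\",\"stop\":true}\n"
        scanner := bufio.NewScanner(strings.NewReader(body))
        for scanner.Scan() {
            line := scanner.Bytes()
            if len(line) == 0 {
                continue // blank keep-alive lines
            }
            evt, ok := bytes.CutPrefix(line, []byte("data: "))
            if !ok {
                fmt.Println("unexpected line:", string(line))
                return
            }
            var c completion
            if err := json.Unmarshal(evt, &c); err != nil {
                fmt.Println("bad event:", err)
                return
            }
            fmt.Printf("content=%q stop=%v\n", c.Content, c.Stop)
        }
    }

bytes.CutPrefix (Go 1.20+) both tests for and strips the prefix in one call, which is why the loop can treat any non-matching, non-empty line as a protocol error.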
diff --git a/parser/parser.go b/parser/parser.go
deleted file mode 100644
index 947848b2..00000000
--- a/parser/parser.go
+++ /dev/null
@@ -1,132 +0,0 @@
-package parser
-
-import (
-	"bufio"
-	"bytes"
-	"errors"
-	"fmt"
-	"io"
-	"log/slog"
-	"slices"
-)
-
-type Command struct {
-	Name string
-	Args string
-}
-
-func (c *Command) Reset() {
-	c.Name = ""
-	c.Args = ""
-}
-
-func Parse(reader io.Reader) ([]Command, error) {
-	var commands []Command
-	var command, modelCommand Command
-
-	scanner := bufio.NewScanner(reader)
-	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
-	scanner.Split(scanModelfile)
-	for scanner.Scan() {
-		line := scanner.Bytes()
-
-		fields := bytes.SplitN(line, []byte(" "), 2)
-		if len(fields) == 0 || len(fields[0]) == 0 {
-			continue
-		}
-
-		switch string(bytes.ToUpper(fields[0])) {
-		case "FROM":
-			command.Name = "model"
-			command.Args = string(bytes.TrimSpace(fields[1]))
-			// copy command for validation
-			modelCommand = command
-		case "ADAPTER":
-			command.Name = string(bytes.ToLower(fields[0]))
-			command.Args = string(bytes.TrimSpace(fields[1]))
-		case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT":
-			command.Name = string(bytes.ToLower(fields[0]))
-			command.Args = string(fields[1])
-		case "PARAMETER":
-			fields = bytes.SplitN(fields[1], []byte(" "), 2)
-			if len(fields) < 2 {
-				return nil, fmt.Errorf("missing value for %s", fields)
-			}
-
-			command.Name = string(fields[0])
-			command.Args = string(bytes.TrimSpace(fields[1]))
-		case "EMBED":
-			return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
-		case "MESSAGE":
-			command.Name = string(bytes.ToLower(fields[0]))
-			fields = bytes.SplitN(fields[1], []byte(" "), 2)
-			if len(fields) < 2 {
-				return nil, fmt.Errorf("should be in the format <role> <message>")
-			}
-			if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
-				return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
-			}
-			command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
-		default:
-			if !bytes.HasPrefix(fields[0], []byte("#")) {
-				// log a warning for unknown commands
-				slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
-			}
-			continue
-		}
-
-		commands = append(commands, command)
-		command.Reset()
-	}
-
-	if modelCommand.Args == "" {
-		return nil, errors.New("no FROM line for the model was specified")
-	}
-
-	return commands, scanner.Err()
-}
-
-func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
-	advance, token, err = scan([]byte(`"""`), []byte(`"""`), data, atEOF)
-	if err != nil {
-		return 0, nil, err
-	}
-
-	if advance > 0 && token != nil {
-		return advance, token, nil
-	}
-
-	advance, token, err = scan([]byte(`"`), []byte(`"`), data, atEOF)
-	if err != nil {
-		return 0, nil, err
-	}
-
-	if advance > 0 && token != nil {
-		return advance, token, nil
-	}
-
-	return bufio.ScanLines(data, atEOF)
-}
-
-func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
-	newline := bytes.IndexByte(data, '\n')
-
-	if start := bytes.Index(data, openBytes); start >= 0 && start < newline {
-		end := bytes.Index(data[start+len(openBytes):], closeBytes)
-		if end < 0 {
-			if atEOF {
-				return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
-			} else {
-				return 0, nil, nil
-			}
-		}
-
-		n := start + len(openBytes) + end + len(closeBytes)
-
-		newData := data[:start]
-		newData = append(newData, data[start+len(openBytes):n-len(closeBytes)]...)
-		return n, newData, nil
-	}
-
-	return 0, nil, nil
-}
diff --git a/parser/parser_test.go b/parser/parser_test.go
deleted file mode 100644
index 25e849b5..00000000
--- a/parser/parser_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package parser
-
-import (
-	"strings"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func Test_Parser(t *testing.T) {
-
-	input := `
-FROM model1
-ADAPTER adapter1
-LICENSE MIT
-PARAMETER param1 value1
-PARAMETER param2 value2
-TEMPLATE template1
-`
-
-	reader := strings.NewReader(input)
-
-	commands, err := Parse(reader)
-	assert.Nil(t, err)
-
-	expectedCommands := []Command{
-		{Name: "model", Args: "model1"},
-		{Name: "adapter", Args: "adapter1"},
-		{Name: "license", Args: "MIT"},
-		{Name: "param1", Args: "value1"},
-		{Name: "param2", Args: "value2"},
-		{Name: "template", Args: "template1"},
-	}
-
-	assert.Equal(t, expectedCommands, commands)
-}
-
-func Test_Parser_NoFromLine(t *testing.T) {
-
-	input := `
-PARAMETER param1 value1
-PARAMETER param2 value2
-`
-
-	reader := strings.NewReader(input)
-
-	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "no FROM line")
-}
-
-func Test_Parser_MissingValue(t *testing.T) {
-
-	input := `
-FROM foo
-PARAMETER param1
-`
-
-	reader := strings.NewReader(input)
-
-	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "missing value for [param1]")
-
-}
-
-func Test_Parser_Messages(t *testing.T) {
-
-	input := `
-FROM foo
-MESSAGE system You are a Parser. Always Parse things.
-MESSAGE user Hey there!
-MESSAGE assistant Hello, I want to parse all the things!
-`
-
-	reader := strings.NewReader(input)
-	commands, err := Parse(reader)
-	assert.Nil(t, err)
-
-	expectedCommands := []Command{
-		{Name: "model", Args: "foo"},
-		{Name: "message", Args: "system: You are a Parser. Always Parse things."},
-		{Name: "message", Args: "user: Hey there!"},
-		{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
-	}
-
-	assert.Equal(t, expectedCommands, commands)
-}
-
-func Test_Parser_Messages_BadRole(t *testing.T) {
-
-	input := `
-FROM foo
-MESSAGE badguy I'm a bad guy!
-`
-
-	reader := strings.NewReader(input)
-	_, err := Parse(reader)
-	assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
-}
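The parser package deleted above is superseded by a Modelfile representation under types/model, which the server/images.go changes later in this patch consume. Based only on the usage visible in the patch (model.File, model.Command{Name, Args}, and File.String), building a Modelfile programmatically would look roughly like this hedged sketch; it assumes the ollama module is on the import path:

    // Hedged sketch of the replacement API, inferred from usage in this patch.
    package main

    import (
        "fmt"

        "github.com/ollama/ollama/types/model"
    )

    func main() {
        var f model.File
        f.Commands = append(f.Commands,
            model.Command{Name: "model", Args: "llama3"},
            model.Command{Name: "system", Args: "You are a helpful assistant."},
        )
        // String renders the commands back into Modelfile syntax.
        fmt.Println(f.String())
    }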
diff --git a/readline/readline.go b/readline/readline.go
index 8ba7d89c..6fa45391 100644
--- a/readline/readline.go
+++ b/readline/readline.go
@@ -218,7 +218,7 @@ func (i *Instance) Readline() (string, error) {
 		case CharCtrlZ:
 			fd := int(syscall.Stdin)
 			return handleCharCtrlZ(fd, i.Terminal.termios)
-		case CharEnter:
+		case CharEnter, CharCtrlJ:
 			output := buf.String()
 			if output != "" {
 				i.History.Add([]rune(output))
@@ -232,7 +232,7 @@ func (i *Instance) Readline() (string, error) {
 				metaDel = false
 				continue
 			}
-			if r >= CharSpace || r == CharEnter {
+			if r >= CharSpace || r == CharEnter || r == CharCtrlJ {
 				buf.Add(r)
 			}
 		}
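The readline change above accepts Ctrl+J as a line terminator alongside Enter. In raw terminal mode Enter typically arrives as carriage return (0x0D) while Ctrl+J is the line-feed control code (0x0A), so both need handling. A tiny illustration; CharEnter and CharCtrlJ here are stand-ins for the package constants, not quoted from it:

    // Sketch: why both CR (Enter) and LF (Ctrl+J) should submit a line.
    package main

    import "fmt"

    const (
        CharCtrlJ = 10 // '\n'
        CharEnter = 13 // '\r'
    )

    func submits(r rune) bool {
        return r == CharEnter || r == CharCtrlJ
    }

    func main() {
        fmt.Println(submits('\r'), submits('\n'), submits('a')) // true true false
    }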
diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1
index 1a89045a..60de0307 100644
--- a/scripts/build_windows.ps1
+++ b/scripts/build_windows.ps1
@@ -7,6 +7,8 @@ $ErrorActionPreference = "Stop"
 
 function checkEnv() {
+    $script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
+    Write-host "Building for ${script:TARGET_ARCH}"
     write-host "Locating required tools and paths"
     $script:SRC_DIR=$PWD
     if (!$env:VCToolsRedistDir) {
@@ -30,7 +32,7 @@ function checkEnv() {
 
     $script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
 
-    $script:DEPS_DIR="${script:SRC_DIR}\dist\windeps"
+    $script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
     $env:CGO_ENABLED="1"
     echo "Checking version"
     if (!$env:VERSION) {
@@ -81,8 +83,8 @@ function buildOllama() {
             /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
         if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
     }
-    New-Item -ItemType Directory -Path .\dist -Force
-    cp .\ollama.exe .\dist\ollama-windows-amd64.exe
+    New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
+    cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
 }
 
 function buildApp() {
@@ -101,7 +103,6 @@ function buildApp() {
 function gatherDependencies() {
     write-host "Gathering runtime dependencies"
     cd "${script:SRC_DIR}"
-    rm -ea 0 -recurse -force -path "${script:DEPS_DIR}"
     md "${script:DEPS_DIR}" -ea 0 > $null
 
     # TODO - this varies based on host build system and MSVC version - drive from dumpbin output
@@ -110,9 +111,6 @@ function gatherDependencies() {
     cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\"
     cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\"
 
-    cp "${script:NVIDIA_DIR}\cudart64_*.dll" "${script:DEPS_DIR}\"
-    cp "${script:NVIDIA_DIR}\cublas64_*.dll" "${script:DEPS_DIR}\"
-    cp "${script:NVIDIA_DIR}\cublasLt64_*.dll" "${script:DEPS_DIR}\"
     cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"
 
     if ("${env:KEY_CONTAINER}") {
@@ -124,7 +122,6 @@ function gatherDependencies() {
         if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
     }
 }
-}
 
 function buildInstaller() {
@@ -132,19 +129,25 @@ function buildInstaller() {
     cd "${script:SRC_DIR}\app"
     $env:PKG_VERSION=$script:PKG_VERSION
    if ("${env:KEY_CONTAINER}") {
-        & "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
+        & "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
     } else {
-        & "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
+        & "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
     }
     if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
 }
 
+function distZip() {
+    write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
+    Compress-Archive -Path "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip" -Force
+}
+
 try {
     checkEnv
     buildOllama
     buildApp
     gatherDependencies
     buildInstaller
+    distZip
 } catch {
     write-host "Build Failed"
     write-host $_
diff --git a/scripts/install.sh b/scripts/install.sh
index eb3ff504..20b0db60 100644
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -166,8 +166,8 @@ fi
 
 if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
     # Look for pre-existing ROCm v6 before downloading the dependencies
-    for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm"; do
-        if [ -n "${search}" ] && [ -e "${search}/lib/libhipblas.so.2" ]; then
+    for search in "${HIP_PATH:-''}" "${ROCM_PATH:-''}" "/opt/rocm" "/usr/lib64"; do
+        if [ -n "${search}" ] && [ -e "${search}/libhipblas.so.2" -o -e "${search}/lib/libhipblas.so.2" ]; then
             status "Compatible AMD GPU ROCm library detected at ${search}"
             install_success
             exit 0
diff --git a/server/envconfig/config.go b/server/envconfig/config.go
new file mode 100644
index 00000000..9ad68180
--- /dev/null
+++ b/server/envconfig/config.go
@@ -0,0 +1,174 @@
+package envconfig
+
+import (
+	"fmt"
+	"log/slog"
+	"os"
+	"path/filepath"
+	"runtime"
+	"strconv"
+	"strings"
+)
+
+var (
+	// Set via OLLAMA_ORIGINS in the environment
+	AllowOrigins []string
+	// Set via OLLAMA_DEBUG in the environment
+	Debug bool
+	// Set via OLLAMA_LLM_LIBRARY in the environment
+	LLMLibrary string
+	// Set via OLLAMA_MAX_LOADED_MODELS in the environment
+	MaxRunners int
+	// Set via OLLAMA_MAX_QUEUE in the environment
+	MaxQueuedRequests int
+	// Set via OLLAMA_MAX_VRAM in the environment
+	MaxVRAM uint64
+	// Set via OLLAMA_NOPRUNE in the environment
+	NoPrune bool
+	// Set via OLLAMA_NUM_PARALLEL in the environment
+	NumParallel int
+	// Set via OLLAMA_RUNNERS_DIR in the environment
+	RunnersDir string
+	// Set via OLLAMA_TMPDIR in the environment
+	TmpDir string
+)
+
+func AsMap() map[string]string {
+	return map[string]string{
+		"OLLAMA_ORIGINS":           fmt.Sprintf("%v", AllowOrigins),
+		"OLLAMA_DEBUG":             fmt.Sprintf("%v", Debug),
+		"OLLAMA_LLM_LIBRARY":       fmt.Sprintf("%v", LLMLibrary),
+		"OLLAMA_MAX_LOADED_MODELS": fmt.Sprintf("%v", MaxRunners),
+		"OLLAMA_MAX_QUEUE":         fmt.Sprintf("%v", MaxQueuedRequests),
+		"OLLAMA_MAX_VRAM":          fmt.Sprintf("%v", MaxVRAM),
+		"OLLAMA_NOPRUNE":           fmt.Sprintf("%v", NoPrune),
+		"OLLAMA_NUM_PARALLEL":      fmt.Sprintf("%v", NumParallel),
+		"OLLAMA_RUNNERS_DIR":       fmt.Sprintf("%v", RunnersDir),
+		"OLLAMA_TMPDIR":            fmt.Sprintf("%v", TmpDir),
+	}
+}
+
+var defaultAllowOrigins = []string{
+	"localhost",
+	"127.0.0.1",
+	"0.0.0.0",
+}
+
+// Clean quotes and spaces from the value
+func clean(key string) string {
+	return strings.Trim(os.Getenv(key), "\"' ")
+}
+
+func init() {
+	// default values
+	NumParallel = 1
+	MaxRunners = 1
+	MaxQueuedRequests = 512
+
+	LoadConfig()
+}
+
+func LoadConfig() {
+	if debug := clean("OLLAMA_DEBUG"); debug != "" {
+		d, err := strconv.ParseBool(debug)
+		if err == nil {
+			Debug = d
+		} else {
+			Debug = true
+		}
+	}
+
+	RunnersDir = clean("OLLAMA_RUNNERS_DIR")
+	if runtime.GOOS == "windows" && RunnersDir == "" {
+		// On Windows we do not carry the payloads inside the main executable
+		appExe, err := os.Executable()
+		if err != nil {
+			slog.Error("failed to lookup executable path", "error", err)
+		}
+
+		cwd, err := os.Getwd()
+		if err != nil {
+			slog.Error("failed to lookup working directory", "error", err)
+		}
+
+		var paths []string
+		for _, root := range []string{filepath.Dir(appExe), cwd} {
+			paths = append(paths,
+				filepath.Join(root),
+				filepath.Join(root, "windows-"+runtime.GOARCH),
+				filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
+			)
+		}
+
+		// Try a few variations to improve developer experience when building from source in the local tree
+		for _, p := range paths {
+			candidate := filepath.Join(p, "ollama_runners")
+			_, err := os.Stat(candidate)
+			if err == nil {
+				RunnersDir = candidate
+				break
+			}
+		}
+		if RunnersDir == "" {
+			slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'")
+		}
+	}
+
+	TmpDir = clean("OLLAMA_TMPDIR")
+
+	userLimit := clean("OLLAMA_MAX_VRAM")
+	if userLimit != "" {
+		avail, err := strconv.ParseUint(userLimit, 10, 64)
+		if err != nil {
+			slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
+		} else {
+			MaxVRAM = avail
+		}
+	}
+
+	LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
+
+	if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {
+		val, err := strconv.Atoi(onp)
+		if err != nil || val <= 0 {
+			slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err)
+		} else {
+			NumParallel = val
+		}
+	}
+
+	if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
+		NoPrune = true
+	}
+
+	if origins := clean("OLLAMA_ORIGINS"); origins != "" {
+		AllowOrigins = strings.Split(origins, ",")
+	}
+	for _, allowOrigin := range defaultAllowOrigins {
+		AllowOrigins = append(AllowOrigins,
+			fmt.Sprintf("http://%s", allowOrigin),
+			fmt.Sprintf("https://%s", allowOrigin),
+			fmt.Sprintf("http://%s:*", allowOrigin),
+			fmt.Sprintf("https://%s:*", allowOrigin),
+		)
+	}
+
+	maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
+	if maxRunners != "" {
+		m, err := strconv.Atoi(maxRunners)
+		if err != nil {
+			slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err)
+		} else {
+			MaxRunners = m
+		}
+	}
+
+	if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" {
+		p, err := strconv.Atoi(onp)
+		if err != nil || p <= 0 {
+			slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err)
+		} else {
+			MaxQueuedRequests = p
+		}
+	}
+}
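The OLLAMA_DEBUG handling in the new envconfig package treats any non-empty value that strconv.ParseBool cannot parse as "debug on", while an empty value leaves the default (false). A trimmed-down version of that logic, matching the behavior the test file below exercises:

    // Sketch: the clean() trim plus ParseBool fallback used for OLLAMA_DEBUG.
    package main

    import (
        "fmt"
        "os"
        "strconv"
        "strings"
    )

    func debugEnabled() bool {
        v := strings.Trim(os.Getenv("OLLAMA_DEBUG"), "\"' ")
        if v == "" {
            return false // unset or empty keeps the default
        }
        b, err := strconv.ParseBool(v)
        if err != nil {
            return true // e.g. OLLAMA_DEBUG=yes still enables debug
        }
        return b
    }

    func main() {
        os.Setenv("OLLAMA_DEBUG", "1")
        fmt.Println(debugEnabled()) // true
        os.Setenv("OLLAMA_DEBUG", "false")
        fmt.Println(debugEnabled()) // false
    }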
RunnersDir == "" { + // On Windows we do not carry the payloads inside the main executable + appExe, err := os.Executable() + if err != nil { + slog.Error("failed to lookup executable path", "error", err) + } + + cwd, err := os.Getwd() + if err != nil { + slog.Error("failed to lookup working directory", "error", err) + } + + var paths []string + for _, root := range []string{filepath.Dir(appExe), cwd} { + paths = append(paths, + filepath.Join(root), + filepath.Join(root, "windows-"+runtime.GOARCH), + filepath.Join(root, "dist", "windows-"+runtime.GOARCH), + ) + } + + // Try a few variations to improve developer experience when building from source in the local tree + for _, p := range paths { + candidate := filepath.Join(p, "ollama_runners") + _, err := os.Stat(candidate) + if err == nil { + RunnersDir = candidate + break + } + } + if RunnersDir == "" { + slog.Error("unable to locate llm runner directory. Set OLLAMA_RUNNERS_DIR to the location of 'ollama_runners'") + } + } + + TmpDir = clean("OLLAMA_TMPDIR") + + userLimit := clean("OLLAMA_MAX_VRAM") + if userLimit != "" { + avail, err := strconv.ParseUint(userLimit, 10, 64) + if err != nil { + slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err) + } else { + MaxVRAM = avail + } + } + + LLMLibrary = clean("OLLAMA_LLM_LIBRARY") + + if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" { + val, err := strconv.Atoi(onp) + if err != nil || val <= 0 { + slog.Error("invalid setting must be greater than zero", "OLLAMA_NUM_PARALLEL", onp, "error", err) + } else { + NumParallel = val + } + } + + if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" { + NoPrune = true + } + + if origins := clean("OLLAMA_ORIGINS"); origins != "" { + AllowOrigins = strings.Split(origins, ",") + } + for _, allowOrigin := range defaultAllowOrigins { + AllowOrigins = append(AllowOrigins, + fmt.Sprintf("http://%s", allowOrigin), + fmt.Sprintf("https://%s", allowOrigin), + fmt.Sprintf("http://%s:*", allowOrigin), + fmt.Sprintf("https://%s:*", allowOrigin), + ) + } + + maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") + if maxRunners != "" { + m, err := strconv.Atoi(maxRunners) + if err != nil { + slog.Error("invalid setting", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "error", err) + } else { + MaxRunners = m + } + } + + if onp := os.Getenv("OLLAMA_MAX_QUEUE"); onp != "" { + p, err := strconv.Atoi(onp) + if err != nil || p <= 0 { + slog.Error("invalid setting", "OLLAMA_MAX_QUEUE", onp, "error", err) + } else { + MaxQueuedRequests = p + } + } +} diff --git a/server/envconfig/config_test.go b/server/envconfig/config_test.go new file mode 100644 index 00000000..b2760299 --- /dev/null +++ b/server/envconfig/config_test.go @@ -0,0 +1,20 @@ +package envconfig + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConfig(t *testing.T) { + os.Setenv("OLLAMA_DEBUG", "") + LoadConfig() + require.False(t, Debug) + os.Setenv("OLLAMA_DEBUG", "false") + LoadConfig() + require.False(t, Debug) + os.Setenv("OLLAMA_DEBUG", "1") + LoadConfig() + require.True(t, Debug) +} diff --git a/server/images.go b/server/images.go index 74fa1a5e..2be1d366 100644 --- a/server/images.go +++ b/server/images.go @@ -1,16 +1,16 @@ package server import ( - "archive/zip" "bytes" + "cmp" "context" "crypto/sha256" + "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" - "io/fs" "log" "log/slog" "net/http" @@ -20,15 +20,16 @@ import ( "runtime" "strconv" "strings" - "text/template" "golang.org/x/exp/slices" "github.com/ollama/ollama/api" - 
"github.com/ollama/ollama/convert" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" - "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/server/envconfig" + "github.com/ollama/ollama/types/errtypes" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -51,7 +52,6 @@ type Model struct { System string License []string Digest string - Size int64 Options map[string]interface{} Messages []Message } @@ -60,6 +60,76 @@ func (m *Model) IsEmbedding() bool { return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") } +func (m *Model) String() string { + var modelfile model.File + + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "model", + Args: m.ModelPath, + }) + + for _, adapter := range m.AdapterPaths { + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "adapter", + Args: adapter, + }) + } + + for _, projector := range m.ProjectorPaths { + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "model", + Args: projector, + }) + } + + if m.Template != "" { + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "template", + Args: m.Template, + }) + } + + if m.System != "" { + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "system", + Args: m.System, + }) + } + + for k, v := range m.Options { + switch v := v.(type) { + case []any: + for _, s := range v { + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: k, + Args: fmt.Sprintf("%v", s), + }) + } + default: + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: k, + Args: fmt.Sprintf("%v", v), + }) + } + } + + for _, license := range m.License { + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "license", + Args: license, + }) + } + + for _, msg := range m.Messages { + modelfile.Commands = append(modelfile.Commands, model.Command{ + Name: "message", + Args: fmt.Sprintf("%s %s", msg.Role, msg.Content), + }) + } + + return modelfile.String() +} + type Message struct { Role string `json:"role"` Content string `json:"content"` @@ -85,50 +155,11 @@ type ConfigV2 struct { RootFS RootFS `json:"rootfs"` } -func (c *ConfigV2) SetModelFormat(format string) { - if c.ModelFormat == "" { - c.ModelFormat = format - } -} - -func (c *ConfigV2) SetModelFamily(families ...string) { - for _, family := range families { - if c.ModelFamily == "" { - c.ModelFamily = family - } - - if !slices.Contains(c.ModelFamilies, family) { - c.ModelFamilies = append(c.ModelFamilies, family) - } - } -} - -func (c *ConfigV2) SetModelType(modelType string) { - if c.ModelType == "" { - c.ModelType = modelType - } -} - -func (c *ConfigV2) SetFileType(fileType string) { - if c.FileType == "" { - c.FileType = fileType - } -} - type RootFS struct { Type string `json:"type"` DiffIDs []string `json:"diff_ids"` } -func (m *ManifestV2) GetTotalSize() (total int64) { - for _, layer := range m.Layers { - total += layer.Size - } - - total += m.Config.Size - return total -} - func GetManifest(mp ModelPath) (*ManifestV2, string, error) { fp, err := mp.GetManifestPath() if err != nil { @@ -169,7 +200,6 @@ func GetModel(name string) (*Model, error) { Digest: digest, Template: "{{ .Prompt }}", License: []string{}, - Size: manifest.GetTotalSize(), } filename, err := GetBlobsPath(manifest.Config.Digest) @@ -259,7 +289,7 @@ func GetModel(name string) (*Model, error) { return model, nil } -func 
@@ -300,250 +323,181 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 		},
 	}
 
-	var layers Layers
-	messages := []string{}
+	var messages []*api.Message
+	parameters := make(map[string]any)
 
-	params := make(map[string][]string)
-	fromParams := make(map[string]any)
-
-	for _, c := range commands {
+	var layers []*Layer
+	for _, c := range modelfile.Commands {
 		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 
 		switch c.Name {
-		case "model":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
+		case "model", "adapter":
+			var baseLayers []*layerWithGGML
+			if name := model.ParseName(c.Args); name.IsValid() {
+				baseLayers, err = parseFromModel(ctx, name, fn)
+				if err != nil {
+					return err
+				}
+			} else if strings.HasPrefix(c.Args, "@") {
+				blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
 				if err != nil {
 					return err
 				}
 
-				c.Args = blobPath
-			}
-
-			pathName := realpath(modelFileDir, c.Args)
-
-			ggufName, err := convertModel(name, pathName, fn)
-			if err != nil {
-				var pathErr *fs.PathError
-				switch {
-				case errors.Is(err, zip.ErrFormat):
-					// it's not a safetensor archive
-				case errors.As(err, &pathErr):
-					// it's not a file on disk, could be a model reference
-				default:
+				blob, err := os.Open(blobpath)
+				if err != nil {
 					return err
 				}
+				defer blob.Close()
+
+				baseLayers, err = parseFromFile(ctx, blob, fn)
+				if err != nil {
+					return err
+				}
+			} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
+				defer file.Close()
+
+				baseLayers, err = parseFromFile(ctx, file, fn)
+				if err != nil {
+					return err
+				}
+			} else {
+				return fmt.Errorf("invalid model reference: %s", c.Args)
 			}
 
-			if ggufName != "" {
-				pathName = ggufName
-				defer os.RemoveAll(ggufName)
-
-				if quantization != "" {
-					quantization = strings.ToUpper(quantization)
-					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
-					tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
+			for _, baseLayer := range baseLayers {
+				if quantization != "" &&
+					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
+					baseLayer.GGML != nil &&
+					baseLayer.GGML.Name() == "gguf" {
+					want, err := llm.ParseFileType(quantization)
 					if err != nil {
 						return err
 					}
-					defer os.RemoveAll(tempfile.Name())
 
-					if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
-						return err
-					}
+					ft := baseLayer.GGML.KV().FileType()
+					if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
+						return errors.New("quantization is only supported for F16 and F32 models")
+					} else if want != ft {
+						fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
 
-					if err := tempfile.Close(); err != nil {
-						return err
-					}
-
-					pathName = tempfile.Name()
-				}
-			}
-
-			bin, err := os.Open(pathName)
-			if err != nil {
-				// not a file on disk so must be a model reference
-				modelpath := ParseModelPath(c.Args)
-				manifest, _, err := GetManifest(modelpath)
-				switch {
-				case errors.Is(err, os.ErrNotExist):
-					fn(api.ProgressResponse{Status: "pulling model"})
-					if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
-						return err
-					}
-
-					manifest, _, err = GetManifest(modelpath)
-					if err != nil {
-						return err
-					}
-				case err != nil:
-					return err
-				}
-
-				fn(api.ProgressResponse{Status: "reading model metadata"})
-				fromConfigPath, err := GetBlobsPath(manifest.Config.Digest)
-				if err != nil {
-					return err
-				}
-
-				fromConfigFile, err := os.Open(fromConfigPath)
-				if err != nil {
-					return err
-				}
-				defer fromConfigFile.Close()
-
-				var fromConfig ConfigV2
-				if err := json.NewDecoder(fromConfigFile).Decode(&fromConfig); err != nil {
-					return err
-				}
-
-				// if the model is still not in gguf format, error out
-				if fromConfig.ModelFormat != "gguf" {
-					return fmt.Errorf("%s is not in gguf format, this base model is not compatible with this version of ollama", c.Args)
-				}
-
-				config.SetModelFormat(fromConfig.ModelFormat)
-				config.SetModelFamily(append(fromConfig.ModelFamilies, fromConfig.ModelFamily)...)
-				config.SetModelType(fromConfig.ModelType)
-				config.SetFileType(fromConfig.FileType)
-
-				for _, layer := range manifest.Layers {
-					deleteMap[layer.Digest] = struct{}{}
-					if layer.MediaType == "application/vnd.ollama.image.params" {
-						fromParamsPath, err := GetBlobsPath(layer.Digest)
+						blob, err := GetBlobsPath(baseLayer.Digest)
 						if err != nil {
 							return err
 						}
 
-						fromParamsFile, err := os.Open(fromParamsPath)
+						temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
 						if err != nil {
 							return err
 						}
-						defer fromParamsFile.Close()
+						defer temp.Close()
+						defer os.Remove(temp.Name())
 
-						if err := json.NewDecoder(fromParamsFile).Decode(&fromParams); err != nil {
+						if err := llm.Quantize(blob, temp.Name(), want); err != nil {
+							return err
+						}
+
+						baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
+						if err != nil {
 							return err
 						}
 					}
-
-					layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
-					if err != nil {
-						return err
-					}
-
-					layers.Add(layer)
 				}
 
-				deleteMap[manifest.Config.Digest] = struct{}{}
-				continue
+				if baseLayer.GGML != nil {
+					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
+					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
+					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
+					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
+					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
+				}
+
+				layers = append(layers, baseLayer.Layer)
 			}
-			defer bin.Close()
-
-			var offset int64
-			for {
-				fn(api.ProgressResponse{Status: "creating model layer"})
-				if _, err := bin.Seek(offset, io.SeekStart); err != nil {
-					return err
-				}
-
-				ggml, size, err := llm.DecodeGGML(bin)
-				if errors.Is(err, io.EOF) {
-					break
-				} else if errors.Is(err, llm.ErrUnsupportedFormat) {
-					return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
-				} else if err != nil {
-					return err
-				}
-
-				config.SetModelFormat(ggml.Name())
-				config.SetModelFamily(ggml.KV().Architecture())
-				config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
-				config.SetFileType(ggml.KV().FileType())
-
-				mediatype := mediatype
-				if ggml.KV().Architecture() == "clip" {
-					mediatype = "application/vnd.ollama.image.projector"
-				}
-
-				sr := io.NewSectionReader(bin, offset, size)
-				layer, err := NewLayer(sr, mediatype)
-				if err != nil {
-					return err
-				}
-
-				layers.Add(layer)
-
-				offset += size
-			}
-		case "adapter":
-			if strings.HasPrefix(c.Args, "@") {
-				blobPath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
-				if err != nil {
-					return err
-				}
-
-				c.Args = blobPath
-			}
-
-			fn(api.ProgressResponse{Status: "creating adapter layer"})
-			bin, err := os.Open(realpath(modelFileDir, c.Args))
-			if err != nil {
-				return err
-			}
-			defer bin.Close()
-
-			_, size, err := llm.DecodeGGML(bin)
+		case "license", "template", "system":
+			blob := strings.NewReader(c.Args)
+			layer, err := NewLayer(blob, mediatype)
 			if err != nil {
 				return err
 			}
 
-			sr := io.NewSectionReader(bin, 0, size)
-			layer, err := NewLayer(sr, mediatype)
-			if err != nil {
-				return err
+			if c.Name != "license" {
+				// replace
+				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+					return layer.MediaType == mediatype
+				})
 			}
 
-			layers.Add(layer)
-		case "license":
-			fn(api.ProgressResponse{Status: "creating license layer"})
-
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
-			if err != nil {
-				return err
-			}
-
-			layers.Add(layer)
-		case "template", "system":
-			fn(api.ProgressResponse{Status: fmt.Sprintf("creating %s layer", c.Name)})
-
-			bin := strings.NewReader(c.Args)
-			layer, err := NewLayer(bin, mediatype)
-			if err != nil {
-				return err
-			}
-
-			layers.Replace(layer)
+			layers = append(layers, layer)
 		case "message":
-			messages = append(messages, c.Args)
+			role, content, ok := strings.Cut(c.Args, ": ")
+			if !ok {
+				return fmt.Errorf("invalid message: %s", c.Args)
+			}
+
+			messages = append(messages, &api.Message{Role: role, Content: content})
 		default:
-			params[c.Name] = append(params[c.Name], c.Args)
+			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
+			if err != nil {
+				return err
+			}
+
+			for k, v := range ps {
+				if ks, ok := parameters[k].([]string); ok {
+					parameters[k] = append(ks, v.([]string)...)
+				} else if vs, ok := v.([]string); ok {
+					parameters[k] = vs
+				} else {
+					parameters[k] = v
+				}
+			}
 		}
 	}
 
-	if len(messages) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
+	var err2 error
+	layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.message":
+			// if there are new messages, remove the inherited ones
+			if len(messages) > 0 {
+				return true
+			}
 
-		msgs := make([]api.Message, 0)
+			return false
+		case "application/vnd.ollama.image.params":
+			// merge inherited parameters with new ones
+			r, err := layer.Open()
+			if err != nil {
+				err2 = err
+				return false
+			}
+			defer r.Close()
 
-		for _, m := range messages {
-			// todo: handle images
-			msg := strings.SplitN(m, ": ", 2)
-			msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
+			var ps map[string]any
+			if err := json.NewDecoder(r).Decode(&ps); err != nil {
+				err2 = err
+				return false
+			}
+
+			for k, v := range ps {
+				if _, ok := parameters[k]; !ok {
+					parameters[k] = v
+				}
+			}
+
+			return true
+		default:
+			return false
 		}
+	})
+	if err2 != nil {
+		return err2
+	}
 
+	if len(messages) > 0 {
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(msgs); err != nil {
+		if err := json.NewEncoder(&b).Encode(messages); err != nil {
 			return err
 		}
@@ -552,39 +506,25 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 			return err
 		}
 
-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}
 
-	if len(params) > 0 {
-		fn(api.ProgressResponse{Status: "creating parameters layer"})
-
-		formattedParams, err := api.FormatParams(params)
-		if err != nil {
-			return err
-		}
-
-		for k, v := range fromParams {
-			if _, ok := formattedParams[k]; !ok {
-				formattedParams[k] = v
-			}
-		}
-
+	if len(parameters) > 0 {
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
+		if err := json.NewEncoder(&b).Encode(parameters); err != nil {
 			return err
 		}
 
-		fn(api.ProgressResponse{Status: "creating config layer"})
 		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
 		if err != nil {
 			return err
 		}
 
-		layers.Replace(layer)
+		layers = append(layers, layer)
 	}
 
-	digests := make([]string, len(layers.items))
-	for i, layer := range layers.items {
+	digests := make([]string, len(layers))
+	for i, layer := range layers {
 		digests[i] = layer.Digest
 	}
@@ -595,36 +535,37 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 		return err
 	}
 
-	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+	layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
 	if err != nil {
 		return err
 	}
 
-	delete(deleteMap, configLayer.Digest)
+	for _, layer := range append(layers, layer) {
+		if layer.status != "" {
+			fn(api.ProgressResponse{Status: layer.status})
+		}
+	}
 
-	for _, layer := range append(layers.items, configLayer) {
-		committed, err := layer.Commit()
-		if err != nil {
-			return err
+	unref := make(map[string]struct{})
+	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
+		for _, layer := range manifest.Layers {
+			if !slices.Contains(digests, layer.Digest) {
+				unref[layer.Digest] = struct{}{}
+			}
 		}
 
-		status := "writing layer"
-		if !committed {
-			status = "using already created layer"
+		if manifest.Config.Digest != layer.Digest {
+			unref[manifest.Config.Digest] = struct{}{}
 		}
-
-		fn(api.ProgressResponse{Status: fmt.Sprintf("%s %s", status, layer.Digest)})
-
-		delete(deleteMap, layer.Digest)
 	}
 
 	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := WriteManifest(name, configLayer, layers.items); err != nil {
+	if err := WriteManifest(name, layer, layers); err != nil {
 		return err
 	}
 
-	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
-		if err := deleteUnusedLayers(nil, deleteMap, false); err != nil {
+	if !envconfig.NoPrune {
+		if err := deleteUnusedLayers(nil, unref, false); err != nil {
 			return err
 		}
 	}
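The slices.DeleteFunc pass above merges parameters inherited from a base model with those set in the new Modelfile, keeping the new value whenever both define the same key. The merge rule in isolation:

    // Sketch: inherited parameters only fill in keys the new Modelfile left unset.
    package main

    import "fmt"

    func main() {
        parameters := map[string]any{"temperature": 0.2} // from the new Modelfile
        inherited := map[string]any{"temperature": 0.8, "stop": []string{"<|end|>"}}

        for k, v := range inherited {
            if _, ok := parameters[k]; !ok {
                parameters[k] = v
            }
        }
        fmt.Println(parameters) // temperature stays 0.2; stop is inherited
    }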
manifest"}) - if err := WriteManifest(name, configLayer, layers.items); err != nil { + if err := WriteManifest(name, layer, layers); err != nil { return err } - if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { - if err := deleteUnusedLayers(nil, deleteMap, false); err != nil { + if !envconfig.NoPrune { + if err := deleteUnusedLayers(nil, unref, false); err != nil { return err } } @@ -633,104 +574,43 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c return nil } -func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) { - r, err := zip.OpenReader(path) - if err != nil { - return "", err +func CopyModel(src, dst model.Name) error { + if !dst.IsFullyQualified() { + return model.Unqualified(dst) } - defer r.Close() - - tempDir, err := os.MkdirTemp("", "ollama-convert") - if err != nil { - return "", err - } - defer os.RemoveAll(tempDir) - - fn(api.ProgressResponse{Status: "unpacking model metadata"}) - for _, f := range r.File { - fpath := filepath.Join(tempDir, f.Name) - outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) - if err != nil { - return "", err - } - - rc, err := f.Open() - if err != nil { - return "", err - } - - _, err = io.Copy(outFile, rc) - if err != nil { - return "", err - } - - outFile.Close() - rc.Close() + if !src.IsFullyQualified() { + return model.Unqualified(src) } - mf, err := convert.GetModelFormat(tempDir) - if err != nil { - return "", err + if src.Filepath() == dst.Filepath() { + return nil } - params, err := mf.GetParams(tempDir) - if err != nil { - return "", err - } - - mArch, err := mf.GetModelArch(name, tempDir, params) - if err != nil { - return "", err - } - - fn(api.ProgressResponse{Status: "processing tensors"}) - if err := mArch.GetTensors(); err != nil { - return "", err - } - - if err := mArch.LoadVocab(); err != nil { - return "", err - } - - fn(api.ProgressResponse{Status: "converting model"}) - path, err = mArch.WriteGGUF() - if err != nil { - return "", err - } - - return path, nil -} - -func CopyModel(src, dest string) error { - srcModelPath := ParseModelPath(src) - srcPath, err := srcModelPath.GetManifestPath() + manifests, err := GetManifestPath() if err != nil { return err } - destModelPath := ParseModelPath(dest) - destPath, err := destModelPath.GetManifestPath() - if err != nil { - return err - } - if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + dstpath := filepath.Join(manifests, dst.Filepath()) + if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil { return err } - // copy the file - input, err := os.ReadFile(srcPath) + srcpath := filepath.Join(manifests, src.Filepath()) + srcfile, err := os.Open(srcpath) if err != nil { - fmt.Println("Error reading file:", err) return err } + defer srcfile.Close() - err = os.WriteFile(destPath, input, 0o644) + dstfile, err := os.Create(dstpath) if err != nil { - fmt.Println("Error reading file:", err) return err } + defer dstfile.Close() - return nil + _, err = io.Copy(dstfile, srcfile) + return err } func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{}, dryRun bool) error { @@ -890,67 +770,6 @@ func DeleteModel(name string) error { return nil } -func ShowModelfile(model *Model) (string, error) { - var mt struct { - *Model - From string - Parameters map[string][]any - } - - mt.Parameters = make(map[string][]any) - for k, v := range model.Options { - if s, ok := v.([]any); ok { - mt.Parameters[k] = s - continue - } - - mt.Parameters[k] = 
[]any{v} - } - - mt.Model = model - mt.From = model.ModelPath - - if model.ParentModel != "" { - mt.From = model.ParentModel - } - - modelFile := `# Modelfile generated by "ollama show" -# To build a new Modelfile based on this one, replace the FROM line with: -# FROM {{ .ShortName }} - -FROM {{ .From }} -TEMPLATE """{{ .Template }}""" - -{{- if .System }} -SYSTEM """{{ .System }}""" -{{- end }} - -{{- range $adapter := .AdapterPaths }} -ADAPTER {{ $adapter }} -{{- end }} - -{{- range $k, $v := .Parameters }} -{{- range $parameter := $v }} -PARAMETER {{ $k }} {{ printf "%#v" $parameter }} -{{- end }} -{{- end }}` - - tmpl, err := template.New("").Parse(modelFile) - if err != nil { - slog.Info(fmt.Sprintf("error parsing template: %q", err)) - return "", err - } - - var buf bytes.Buffer - - if err = tmpl.Execute(&buf, mt); err != nil { - slog.Info(fmt.Sprintf("error executing template: %q", err)) - return "", err - } - - return buf.String(), nil -} - func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) @@ -972,9 +791,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu for _, layer := range layers { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { slog.Info(fmt.Sprintf("error uploading blob: %v", err)) - if errors.Is(err, errUnauthorized) { - return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository()) - } return err } } @@ -1011,7 +827,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu // build deleteMap to prune unused layers deleteMap := make(map[string]struct{}) - if noprune = os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if !envconfig.NoPrune { manifest, _, err = GetManifest(mp) if err != nil && !errors.Is(err, os.ErrNotExist) { return err @@ -1137,9 +953,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), n } -var errUnauthorized = fmt.Errorf("unauthorized") +var errUnauthorized = fmt.Errorf("unauthorized: access denied") + +// getTokenSubject returns the subject of a JWT token, it does not validate the token +func getTokenSubject(token string) string { + parts := strings.Split(token, ".") + if len(parts) != 3 { + slog.Error("jwt token does not contain 3 parts") + return "" + } + + payload := parts[1] + payloadBytes, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err)) + return "" + } + + var payloadMap map[string]interface{} + if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil { + slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err)) + return "" + } + + sub, ok := payloadMap["sub"] + if !ok { + slog.Error("jwt does not contain 'sub' field") + return "" + } + + return fmt.Sprintf("%s", sub) +} func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { + anonymous := true // access will default to anonymous if no user is found associated with the public key for i := 0; i < 2; i++ { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { @@ -1158,6 +1005,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR if err != nil { return 
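getTokenSubject above reads the JWT "sub" claim without verifying the signature, which is all that is needed to decide whether a registry token is anonymous. The decoding step in isolation, using a synthetic unsigned token for demonstration:

    // Sketch: extract the "sub" claim from a JWT payload (no verification).
    package main

    import (
        "encoding/base64"
        "encoding/json"
        "fmt"
        "strings"
    )

    func main() {
        // header.payload.signature; only the payload matters here
        token := "x." + base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"anonymous"}`)) + ".y"

        payload := strings.Split(token, ".")[1]
        raw, err := base64.RawURLEncoding.DecodeString(payload)
        if err != nil {
            panic(err)
        }

        var claims map[string]interface{}
        if err := json.Unmarshal(raw, &claims); err != nil {
            panic(err)
        }
        fmt.Println(claims["sub"]) // anonymous
    }

JWT payloads use unpadded base64url, hence base64.RawURLEncoding rather than StdEncoding.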
@@ -1158,6 +1005,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 			if err != nil {
 				return nil, err
 			}
+			anonymous = getTokenSubject(token) == "anonymous"
 			regOpts.Token = token
 			if body != nil {
 				_, err = body.Seek(0, io.SeekStart)
@@ -1178,6 +1026,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 		}
 	}
 
+	if anonymous {
+		// no user is associated with the public key, and the request requires non-anonymous access
+		pubKey, nestedErr := auth.GetPublicKey()
+		if nestedErr != nil {
+			slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
+			return nil, errUnauthorized
+		}
+		return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
+	}
+
+	// user is associated with the public key, but is not authorized to make the request
 	return nil, errUnauthorized
 }
@@ -1255,7 +1113,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
 		}
 	}
 }
 
-var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
+var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
 
 func verifyBlob(digest string) error {
 	fp, err := GetBlobsPath(digest)
diff --git a/server/layers.go b/server/layer.go
similarity index 53%
rename from server/layers.go
rename to server/layer.go
index 07787406..dcca3854 100644
--- a/server/layers.go
+++ b/server/layer.go
@@ -5,39 +5,14 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"strings"
-
-	"golang.org/x/exp/slices"
 )
 
-type Layers struct {
-	items []*Layer
-}
-
-func (ls *Layers) Add(layer *Layer) {
-	if layer.Size > 0 {
-		ls.items = append(ls.items, layer)
-	}
-}
-
-func (ls *Layers) Replace(layer *Layer) {
-	if layer.Size > 0 {
-		mediatype := layer.MediaType
-		layers := slices.DeleteFunc(ls.items, func(l *Layer) bool {
-			return l.MediaType == mediatype
-		})
-
-		ls.items = append(layers, layer)
-	}
-}
-
 type Layer struct {
 	MediaType string `json:"mediaType"`
 	Digest    string `json:"digest"`
 	Size      int64  `json:"size"`
 	From      string `json:"from,omitempty"`
-
-	tempFileName string
+	status    string
 }
 
 func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
@@ -46,14 +21,12 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}
 
-	const delimiter = "-"
-
-	pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter)
-	temp, err := os.CreateTemp(blobs, pattern)
+	temp, err := os.CreateTemp(blobs, "sha256-")
 	if err != nil {
 		return nil, err
 	}
 	defer temp.Close()
+	defer os.Remove(temp.Name())
 
 	sha256sum := sha256.New()
 	n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
@@ -61,11 +34,29 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
 		return nil, err
 	}
 
+	if err := temp.Close(); err != nil {
+		return nil, err
+	}
+
+	digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
+	blob, err := GetBlobsPath(digest)
+	if err != nil {
+		return nil, err
+	}
+
+	status := "using existing layer"
+	if _, err := os.Stat(blob); err != nil {
+		status = "creating new layer"
+		if err := os.Rename(temp.Name(), blob); err != nil {
+			return nil, err
+		}
+	}
+
 	return &Layer{
-		MediaType:    mediatype,
-		Digest:       fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)),
-		Size:         n,
-		tempFileName: temp.Name(),
+		MediaType: mediatype,
+		Digest:    digest,
+		Size:      n,
+		status:    fmt.Sprintf("%s %s", status, digest),
 	}, nil
 }
 
@@ -85,21 +76,15 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
 		Digest:    digest,
 		Size:      fi.Size(),
 		From:      from,
+		status:    fmt.Sprintf("using existing layer %s", digest),
 	}, nil
 }
 
-func (l *Layer) Commit() (bool, error) {
-	// always remove temp
-	defer os.Remove(l.tempFileName)
-
+func (l *Layer) Open() (io.ReadCloser, error) {
 	blob, err := GetBlobsPath(l.Digest)
 	if err != nil {
-		return false, err
+		return nil, err
 	}
 
-	if _, err := os.Stat(blob); err != nil {
-		return true, os.Rename(l.tempFileName, blob)
-	}
-
-	return false, nil
+	return os.Open(blob)
 }
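The NewLayer rewrite above makes blob storage content-addressed in one step: hash while spooling to a temp file, then rename into place only if the digest is new, removing the separate Commit phase. A stripped-down version of the flow, with placeholder names:

    // Sketch: hash-while-writing, then rename into a content-addressed slot.
    package main

    import (
        "crypto/sha256"
        "fmt"
        "io"
        "os"
        "path/filepath"
        "strings"
    )

    func writeBlob(dir string, r io.Reader) (string, error) {
        temp, err := os.CreateTemp(dir, "sha256-")
        if err != nil {
            return "", err
        }
        defer os.Remove(temp.Name()) // no-op if the rename below succeeded

        h := sha256.New()
        // every byte written to the temp file is also fed to the hash
        if _, err := io.Copy(io.MultiWriter(temp, h), r); err != nil {
            temp.Close()
            return "", err
        }
        if err := temp.Close(); err != nil {
            return "", err
        }

        digest := fmt.Sprintf("sha256-%x", h.Sum(nil))
        blob := filepath.Join(dir, digest)
        if _, err := os.Stat(blob); err != nil {
            // new content: move the temp file into its slot
            if err := os.Rename(temp.Name(), blob); err != nil {
                return "", err
            }
        }
        return digest, nil
    }

    func main() {
        dir, _ := os.MkdirTemp("", "blobs")
        defer os.RemoveAll(dir)
        d, err := writeBlob(dir, strings.NewReader("hello"))
        fmt.Println(d, err)
    }

If the digest already exists the temp file is simply discarded by the deferred Remove, which is what lets the new code report "using existing layer" without re-writing anything.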
diff --git a/server/manifest.go b/server/manifest.go
new file mode 100644
index 00000000..8a17700e
--- /dev/null
+++ b/server/manifest.go
@@ -0,0 +1,79 @@
+package server
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/json"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+
+	"github.com/ollama/ollama/types/model"
+)
+
+type Manifest struct {
+	ManifestV2
+	Digest string `json:"-"`
+}
+
+func (m *Manifest) Size() (size int64) {
+	for _, layer := range append(m.Layers, m.Config) {
+		size += layer.Size
+	}
+
+	return
+}
+
+func ParseNamedManifest(name model.Name) (*Manifest, error) {
+	if !name.IsFullyQualified() {
+		return nil, model.Unqualified(name)
+	}
+
+	manifests, err := GetManifestPath()
+	if err != nil {
+		return nil, err
+	}
+
+	var manifest ManifestV2
+	manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
+	if err != nil {
+		return nil, err
+	}
+
+	sha256sum := sha256.New()
+	if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
+		return nil, err
+	}
+
+	return &Manifest{
+		ManifestV2: manifest,
+		Digest:     fmt.Sprintf("%x", sha256sum.Sum(nil)),
+	}, nil
+}
+
+func WriteManifest(name string, config *Layer, layers []*Layer) error {
+	manifest := ManifestV2{
+		SchemaVersion: 2,
+		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
+		Config:        config,
+		Layers:        layers,
+	}
+
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
+		return err
+	}
+
+	modelpath := ParseModelPath(name)
+	manifestPath, err := modelpath.GetManifestPath()
+	if err != nil {
+		return err
+	}
+
+	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
+		return err
+	}
+
+	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
+}
diff --git a/server/manifests.go b/server/manifests.go
deleted file mode 100644
index 2b39db65..00000000
--- a/server/manifests.go
+++ /dev/null
@@ -1,34 +0,0 @@
-package server
-
-import (
-	"bytes"
-	"encoding/json"
-	"os"
-	"path/filepath"
-)
-
-func WriteManifest(name string, config *Layer, layers []*Layer) error {
-	manifest := ManifestV2{
-		SchemaVersion: 2,
-		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
-		Config:        config,
-		Layers:        layers,
-	}
-
-	var b bytes.Buffer
-	if err := json.NewEncoder(&b).Encode(manifest); err != nil {
-		return err
-	}
-
-	modelpath := ParseModelPath(name)
-	manifestPath, err := modelpath.GetManifestPath()
-	if err != nil {
-		return err
-	}
-
-	if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
-		return err
-	}
-
-	return os.WriteFile(manifestPath, b.Bytes(), 0o644)
-}
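ParseNamedManifest above computes the manifest digest by hashing the bytes as they stream through the JSON decoder, avoiding a second read of the file. The io.TeeReader trick in isolation:

    // Sketch: hash a JSON document while decoding it in a single pass.
    package main

    import (
        "crypto/sha256"
        "encoding/json"
        "fmt"
        "io"
        "strings"
    )

    func main() {
        payload := `{"schemaVersion":2}`

        var v map[string]any
        h := sha256.New()
        // every byte the decoder consumes is also written to the hash
        if err := json.NewDecoder(io.TeeReader(strings.NewReader(payload), h)).Decode(&v); err != nil {
            panic(err)
        }
        fmt.Printf("%x\n", h.Sum(nil))
    }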
err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil { + return nil, err + } + + modelpath = ParseModelPath(name.String()) + manifest, _, err = GetManifest(modelpath) + if err != nil { + return nil, err + } + case err != nil: + return nil, err + } + + for _, layer := range manifest.Layers { + layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname()) + if err != nil { + return nil, err + } + + switch layer.MediaType { + case "application/vnd.ollama.image.model", + "application/vnd.ollama.image.projector", + "application/vnd.ollama.image.adapter": + blobpath, err := GetBlobsPath(layer.Digest) + if err != nil { + return nil, err + } + + blob, err := os.Open(blobpath) + if err != nil { + return nil, err + } + defer blob.Close() + + ggml, _, err := llm.DecodeGGML(blob) + if err != nil { + return nil, err + } + + layers = append(layers, &layerWithGGML{layer, ggml}) + default: + layers = append(layers, &layerWithGGML{layer, nil}) + } + + } + + return layers, nil +} + +func parseFromZipFile(_ context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { + stat, err := file.Stat() + if err != nil { + return nil, err + } + + r, err := zip.NewReader(file, stat.Size()) + if err != nil { + return nil, err + } + + tempdir, err := os.MkdirTemp(filepath.Dir(file.Name()), "") + if err != nil { + return nil, err + } + defer os.RemoveAll(tempdir) + + fn(api.ProgressResponse{Status: "unpacking model metadata"}) + for _, f := range r.File { + // TODO(mxyng): this should not write out all files to disk + outfile, err := os.Create(filepath.Join(tempdir, f.Name)) + if err != nil { + return nil, err + } + defer outfile.Close() + + infile, err := f.Open() + if err != nil { + return nil, err + } + defer infile.Close() + + if _, err = io.Copy(outfile, infile); err != nil { + return nil, err + } + + if err := outfile.Close(); err != nil { + return nil, err + } + + if err := infile.Close(); err != nil { + return nil, err + } + } + + mf, err := convert.GetModelFormat(tempdir) + if err != nil { + return nil, err + } + + params, err := mf.GetParams(tempdir) + if err != nil { + return nil, err + } + + mArch, err := mf.GetModelArch("", tempdir, params) + if err != nil { + return nil, err + } + + fn(api.ProgressResponse{Status: "processing tensors"}) + if err := mArch.GetTensors(); err != nil { + return nil, err + } + + if err := mArch.LoadVocab(); err != nil { + return nil, err + } + + fn(api.ProgressResponse{Status: "converting model"}) + + // TODO(mxyng): this should write directly into a layer + // e.g.
NewLayer(arch.Reader(), "application/vnd.ollama.image.model") + temp, err := os.CreateTemp(tempdir, "fp16") + if err != nil { + return nil, err + } + defer temp.Close() + defer os.Remove(temp.Name()) + + if err = mArch.WriteGGUF(temp); err != nil { + return nil, err + } + + if _, err := temp.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + layer, err := NewLayer(temp, "application/vnd.ollama.image.model") + if err != nil { + return nil, fmt.Errorf("failed to create layer: %w", err) + } + + blobpath, err := GetBlobsPath(layer.Digest) + if err != nil { + return nil, err + } + + bin, err := os.Open(blobpath) + if err != nil { + return nil, err + } + defer bin.Close() + + ggml, _, err := llm.DecodeGGML(bin) + if err != nil { + return nil, err + } + + layer, err = NewLayerFromLayer(layer.Digest, layer.MediaType, "") + if err != nil { + return nil, err + } + + layers = append(layers, &layerWithGGML{layer, ggml}) + return layers, nil +} + +func parseFromFile(ctx context.Context, file *os.File, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { + sr := io.NewSectionReader(file, 0, 512) + contentType, err := detectContentType(sr) + if err != nil { + return nil, err + } + + switch contentType { + case "gguf", "ggla": + // noop + case "application/zip": + return parseFromZipFile(ctx, file, fn) + default: + return nil, fmt.Errorf("unsupported content type: %s", contentType) + } + + stat, err := file.Stat() + if err != nil { + return nil, err + } + + var offset int64 + for offset < stat.Size() { + ggml, n, err := llm.DecodeGGML(file) + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, err + } + + mediatype := "application/vnd.ollama.image.model" + if ggml.Name() == "ggla" { + mediatype = "application/vnd.ollama.image.adapter" + } else if ggml.KV().Architecture() == "clip" { + mediatype = "application/vnd.ollama.image.projector" + } + + layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype) + if err != nil { + return nil, err + } + + layers = append(layers, &layerWithGGML{layer, ggml}) + offset = n + } + + return layers, nil +} + +func detectContentType(r io.Reader) (string, error) { + var b bytes.Buffer + if _, err := io.Copy(&b, r); err != nil { + return "", err + } + + if contentType := llm.DetectGGMLType(b.Bytes()); contentType != "" { + return contentType, nil + } + + if contentType := http.DetectContentType(b.Bytes()); contentType != "application/octet-stream" { + return contentType, nil + } + + return "unknown", nil +} diff --git a/server/modelpath.go b/server/modelpath.go index 7d333876..86908226 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -6,6 +6,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" ) @@ -25,9 +26,10 @@ const ( ) var ( - ErrInvalidImageFormat = errors.New("invalid image format") - ErrInvalidProtocol = errors.New("invalid protocol scheme") - ErrInsecureProtocol = errors.New("insecure protocol http") + ErrInvalidImageFormat = errors.New("invalid image format") + ErrInvalidProtocol = errors.New("invalid protocol scheme") + ErrInsecureProtocol = errors.New("insecure protocol http") + ErrInvalidDigestFormat = errors.New("invalid digest format") ) func ParseModelPath(name string) ModelPath { @@ -149,6 +151,14 @@ func GetBlobsPath(digest string) (string, error) { return "", err } + // only accept actual sha256 digests + pattern := "^sha256[:-][0-9a-fA-F]{64}$" + re := regexp.MustCompile(pattern) + + if digest != "" && !re.MatchString(digest) { + return "",
ErrInvalidDigestFormat + } + digest = strings.ReplaceAll(digest, ":", "-") path := filepath.Join(dir, "blobs", digest) dirPath := filepath.Dir(path) diff --git a/server/modelpath_test.go b/server/modelpath_test.go index 8b26d52c..30741d87 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -1,6 +1,73 @@ package server -import "testing" +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetBlobsPath(t *testing.T) { + // GetBlobsPath expects an actual directory to exist + dir, err := os.MkdirTemp("", "ollama-test") + assert.Nil(t, err) + defer os.RemoveAll(dir) + + tests := []struct { + name string + digest string + expected string + err error + }{ + { + "empty digest", + "", + filepath.Join(dir, "blobs"), + nil, + }, + { + "valid with colon", + "sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9", + filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"), + nil, + }, + { + "valid with dash", + "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9", + filepath.Join(dir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"), + nil, + }, + { + "digest too short", + "sha256-45640291", + "", + ErrInvalidDigestFormat, + }, + { + "digest too long", + "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa", + "", + ErrInvalidDigestFormat, + }, + { + "digest invalid chars", + "../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a", + "", + ErrInvalidDigestFormat, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("OLLAMA_MODELS", dir) + + got, err := GetBlobsPath(tc.digest) + + assert.ErrorIs(t, err, tc.err, tc.name) + assert.Equal(t, tc.expected, got, tc.name) + }) + } +} func TestParseModelPath(t *testing.T) { tests := []struct { diff --git a/server/routes.go b/server/routes.go index b0d36b14..7dfeb513 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1,6 +1,7 @@ package server import ( + "cmp" "context" "encoding/json" "errors" @@ -15,11 +16,8 @@ import ( "os" "os/signal" "path/filepath" - "reflect" - "runtime" "strconv" "strings" - "sync" "syscall" "time" @@ -31,14 +29,16 @@ import ( "github.com/ollama/ollama/gpu" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/server/envconfig" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) var mode string = gin.DebugMode type Server struct { - addr net.Addr + addr net.Addr + sched *Scheduler } func init() { @@ -53,88 +53,8 @@ func init() { gin.SetMode(mode) } -var loaded struct { - mu sync.Mutex - - llama *llm.LlamaServer - - expireTimer *time.Timer - - model string - adapters []string - projectors []string - *api.Options -} - var defaultSessionDuration = 5 * time.Minute -func unload() { - if loaded.llama != nil { - loaded.llama.Close() - } - - loaded.llama = nil - loaded.model = "" - loaded.adapters = nil - loaded.projectors = nil - loaded.Options = nil -} - -// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function -func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error { - ctx, cancel := context.WithTimeout(c, 10*time.Second) - defer cancel() - - needLoad := loaded.llama == nil || // is there a model loaded?
- loaded.model != model.ModelPath || // has the base model changed? - !reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed? - !reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed? - !reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed? - loaded.llama.Ping(ctx) != nil - - if needLoad { - if loaded.llama != nil { - slog.Info("changing loaded model") - unload() - } - - llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts) - if err != nil { - // some older models are not compatible with newer versions of llama.cpp - // show a generalized compatibility error until there is a better way to - // check for model compatibility - if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { - err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName) - } - - return err - } - - loaded.model = model.ModelPath - loaded.adapters = model.AdapterPaths - loaded.projectors = model.ProjectorPaths - loaded.llama = llama - loaded.Options = &opts - - if err = llama.WaitUntilRunning(); err != nil { - slog.Error("error loading llama server", "error", err) - unload() - return err - } - } - - if loaded.expireTimer == nil { - loaded.expireTimer = time.AfterFunc(sessionDuration, func() { - loaded.mu.Lock() - defer loaded.mu.Unlock() - unload() - }) - } - - loaded.expireTimer.Reset(sessionDuration) - return nil -} - func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { @@ -154,9 +74,7 @@ func isSupportedImageType(image []byte) bool { return slices.Contains(allowedTypes, contentType) } -func GenerateHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() +func (s *Server) GenerateHandler(c *gin.Context) { checkpointStart := time.Now() var req api.GenerateRequest @@ -224,8 +142,12 @@ func GenerateHandler(c *gin.Context) { sessionDuration = req.KeepAlive.Duration } - if err := load(c, model, opts, sessionDuration); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + var runner *runnerRef + select { + case runner = <-rCh: + case err = <-eCh: + handleErrorResponse(c, err) return } @@ -275,7 +197,7 @@ func GenerateHandler(c *gin.Context) { sb.Reset() if req.Context != nil { - prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context) + prev, err := runner.llama.Detokenize(c.Request.Context(), req.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -297,9 +219,6 @@ func GenerateHandler(c *gin.Context) { defer close(ch) fn := func(r llm.CompletionResponse) { - // Update model expiration - loaded.expireTimer.Reset(sessionDuration) - // Build up the full response if _, err := generated.WriteString(r.Content); err != nil { ch <- gin.H{"error": err.Error()} @@ -331,7 +250,7 @@ func GenerateHandler(c *gin.Context) { } // TODO (jmorganca): encode() should not strip special tokens - tokens, err := loaded.llama.Tokenize(c.Request.Context(), p) + tokens, err := runner.llama.Tokenize(c.Request.Context(), p) if err != nil { ch <- gin.H{"error": err.Error()} return @@ -359,7 +278,7 @@ func GenerateHandler(c 
*gin.Context) { Images: images, Options: opts, } - if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil { + if err := runner.llama.Completion(c.Request.Context(), req, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -421,10 +340,7 @@ func getDefaultSessionDuration() time.Duration { return defaultSessionDuration } -func EmbeddingsHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() - +func (s *Server) EmbeddingsHandler(c *gin.Context) { var req api.EmbeddingRequest err := c.ShouldBindJSON(&req) switch { @@ -469,8 +385,12 @@ func EmbeddingsHandler(c *gin.Context) { sessionDuration = req.KeepAlive.Duration } - if err := load(c, model, opts, sessionDuration); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + var runner *runnerRef + select { + case runner = <-rCh: + case err = <-eCh: + handleErrorResponse(c, err) return } @@ -480,7 +400,7 @@ func EmbeddingsHandler(c *gin.Context) { return } - embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt) + embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) @@ -493,7 +413,7 @@ func EmbeddingsHandler(c *gin.Context) { c.JSON(http.StatusOK, resp) } -func PullModelHandler(c *gin.Context) { +func (s *Server) PullModelHandler(c *gin.Context) { var req api.PullRequest err := c.ShouldBindJSON(&req) switch { @@ -542,7 +462,7 @@ func PullModelHandler(c *gin.Context) { streamResponse(c, ch) } -func PushModelHandler(c *gin.Context) { +func (s *Server) PushModelHandler(c *gin.Context) { var req api.PushRequest err := c.ShouldBindJSON(&req) switch { @@ -591,30 +511,19 @@ func PushModelHandler(c *gin.Context) { streamResponse(c, ch) } -func CreateModelHandler(c *gin.Context) { +func (s *Server) CreateModelHandler(c *gin.Context) { var req api.CreateRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - var model string - if req.Model != "" { - model = req.Model - } else if req.Name != "" { - model = req.Name - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) - return - } - - if err := ParseModelPath(model).Validate(); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + name := model.ParseName(cmp.Or(req.Model, req.Name)) + if !name.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) return } @@ -623,19 +532,19 @@ func CreateModelHandler(c *gin.Context) { return } - var modelfile io.Reader = strings.NewReader(req.Modelfile) + var r io.Reader = strings.NewReader(req.Modelfile) if req.Path != "" && req.Modelfile == "" { - mf, err := os.Open(req.Path) + f, err := os.Open(req.Path) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) return } - defer mf.Close() + defer f.Close() - modelfile = mf + r = f } - commands, err := parser.Parse(modelfile) + modelfile, err := model.ParseFile(r) if err != 
nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -651,7 +560,7 @@ func CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil { + if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(req.Quantization), modelfile, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -664,7 +573,7 @@ func CreateModelHandler(c *gin.Context) { streamResponse(c, ch) } -func DeleteModelHandler(c *gin.Context) { +func (s *Server) DeleteModelHandler(c *gin.Context) { var req api.DeleteRequest err := c.ShouldBindJSON(&req) switch { @@ -709,7 +618,7 @@ func DeleteModelHandler(c *gin.Context) { c.JSON(http.StatusOK, nil) } -func ShowModelHandler(c *gin.Context) { +func (s *Server) ShowModelHandler(c *gin.Context) { var req api.ShowRequest err := c.ShouldBindJSON(&req) switch { @@ -799,109 +708,115 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } } - mf, err := ShowModelfile(model) - if err != nil { - return nil, err - } - - resp.Modelfile = mf + var sb strings.Builder + fmt.Fprintln(&sb, "# Modelfile generated by \"ollama show\"") + fmt.Fprintln(&sb, "# To build a new Modelfile based on this, replace FROM with:") + fmt.Fprintf(&sb, "# FROM %s\n\n", model.ShortName) + fmt.Fprint(&sb, model.String()) + resp.Modelfile = sb.String() return resp, nil } -func ListModelsHandler(c *gin.Context) { - models := make([]api.ModelResponse, 0) - manifestsPath, err := GetManifestPath() +func (s *Server) ListModelsHandler(c *gin.Context) { + manifests, err := GetManifestPath() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - modelResponse := func(modelName string) (api.ModelResponse, error) { - model, err := GetModel(modelName) - if err != nil { - return api.ModelResponse{}, err - } - - modelDetails := api.ModelDetails{ - Format: model.Config.ModelFormat, - Family: model.Config.ModelFamily, - Families: model.Config.ModelFamilies, - ParameterSize: model.Config.ModelType, - QuantizationLevel: model.Config.FileType, - } - - return api.ModelResponse{ - Model: model.ShortName, - Name: model.ShortName, - Size: model.Size, - Digest: model.Digest, - Details: modelDetails, - }, nil - } - - walkFunc := func(path string, info os.FileInfo, _ error) error { + var models []api.ModelResponse + if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error { if !info.IsDir() { - path, tag := filepath.Split(path) - model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator)) - modelPath := strings.Join([]string{model, tag}, ":") - canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/") - - resp, err := modelResponse(canonicalModelPath) + rel, err := filepath.Rel(manifests, path) if err != nil { - slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath)) - // nolint: nilerr + return err + } + + if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil { + return err + } else if hidden { return nil } - resp.ModifiedAt = info.ModTime() - models = append(models, resp) + n := model.ParseNameFromFilepath(rel) + m, err := ParseNamedManifest(n) + if err != nil { + return err + } + + f, err := m.Config.Open() + if err != nil { + return err + } + defer f.Close() + + var c ConfigV2 + if err := json.NewDecoder(f).Decode(&c); err != nil { + return err + } + + // tag
should never be masked + models = append(models, api.ModelResponse{ + Model: n.DisplayShortest(), + Name: n.DisplayShortest(), + Size: m.Size(), + Digest: m.Digest, + ModifiedAt: info.ModTime(), + Details: api.ModelDetails{ + Format: c.ModelFormat, + Family: c.ModelFamily, + Families: c.ModelFamilies, + ParameterSize: c.ModelType, + QuantizationLevel: c.FileType, + }, + }) } return nil - } - - if err := filepath.Walk(manifestsPath, walkFunc); err != nil { + }); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + slices.SortStableFunc(models, func(i, j api.ModelResponse) int { + // most recently modified first + return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix()) + }) + c.JSON(http.StatusOK, api.ListResponse{Models: models}) } -func CopyModelHandler(c *gin.Context) { - var req api.CopyRequest - err := c.ShouldBindJSON(&req) - switch { - case errors.Is(err, io.EOF): +func (s *Server) CopyModelHandler(c *gin.Context) { + var r api.CopyRequest + if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) return - case err != nil: + } else if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if req.Source == "" || req.Destination == "" { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"}) + src := model.ParseName(r.Source) + if !src.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("source %q is invalid", r.Source)}) return } - if err := ParseModelPath(req.Destination).Validate(); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + dst := model.ParseName(r.Destination) + if !dst.IsValid() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("destination %q is invalid", r.Destination)}) return } - if err := CopyModel(req.Source, req.Destination); err != nil { - if os.IsNotExist(err) { - c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)}) - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - } - return + if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)}) + } else if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) } } -func HeadBlobHandler(c *gin.Context) { +func (s *Server) HeadBlobHandler(c *gin.Context) { path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -916,7 +831,7 @@ func HeadBlobHandler(c *gin.Context) { c.Status(http.StatusOK) } -func CreateBlobHandler(c *gin.Context) { +func (s *Server) CreateBlobHandler(c *gin.Context) { path, err := GetBlobsPath(c.Param("digest")) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) @@ -946,20 +861,9 @@ func CreateBlobHandler(c *gin.Context) { return } - if _, err := layer.Commit(); err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - c.Status(http.StatusCreated) } -var defaultAllowOrigins = []string{ - "localhost", - "127.0.0.1", - "0.0.0.0", -} - func isLocalIP(ip netip.Addr) bool { if interfaces, err := net.Interfaces(); err == nil { for _, iface := range interfaces { @@ -1031,6 +935,11 @@ func allowedHostsMiddleware(addr net.Addr) 
gin.HandlerFunc { } if allowedHost(host) { + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(http.StatusNoContent) + return + } + c.Next() return } @@ -1043,19 +952,8 @@ func (s *Server) GenerateRoutes() http.Handler { config := cors.DefaultConfig() config.AllowWildcard = true config.AllowBrowserExtensions = true - - if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" { - config.AllowOrigins = strings.Split(allowedOrigins, ",") - } - - for _, allowOrigin := range defaultAllowOrigins { - config.AllowOrigins = append(config.AllowOrigins, - fmt.Sprintf("http://%s", allowOrigin), - fmt.Sprintf("https://%s", allowOrigin), - fmt.Sprintf("http://%s:*", allowOrigin), - fmt.Sprintf("https://%s:*", allowOrigin), - ) - } + config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} + config.AllowOrigins = envconfig.AllowOrigins r := gin.Default() r.Use( @@ -1063,27 +961,27 @@ func (s *Server) GenerateRoutes() http.Handler { allowedHostsMiddleware(s.addr), ) - r.POST("/api/pull", PullModelHandler) - r.POST("/api/generate", GenerateHandler) - r.POST("/api/chat", ChatHandler) - r.POST("/api/embeddings", EmbeddingsHandler) - r.POST("/api/create", CreateModelHandler) - r.POST("/api/push", PushModelHandler) - r.POST("/api/copy", CopyModelHandler) - r.DELETE("/api/delete", DeleteModelHandler) - r.POST("/api/show", ShowModelHandler) - r.POST("/api/blobs/:digest", CreateBlobHandler) - r.HEAD("/api/blobs/:digest", HeadBlobHandler) + r.POST("/api/pull", s.PullModelHandler) + r.POST("/api/generate", s.GenerateHandler) + r.POST("/api/chat", s.ChatHandler) + r.POST("/api/embeddings", s.EmbeddingsHandler) + r.POST("/api/create", s.CreateModelHandler) + r.POST("/api/push", s.PushModelHandler) + r.POST("/api/copy", s.CopyModelHandler) + r.DELETE("/api/delete", s.DeleteModelHandler) + r.POST("/api/show", s.ShowModelHandler) + r.POST("/api/blobs/:digest", s.CreateBlobHandler) + r.HEAD("/api/blobs/:digest", s.HeadBlobHandler) // Compatibility endpoints - r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler) + r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler) for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") }) - r.Handle(method, "/api/tags", ListModelsHandler) + r.Handle(method, "/api/tags", s.ListModelsHandler) r.Handle(method, "/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) @@ -1094,10 +992,11 @@ func (s *Server) GenerateRoutes() http.Handler { func Serve(ln net.Listener) error { level := slog.LevelInfo - if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + if envconfig.Debug { level = slog.LevelDebug } + slog.Info("server config", "env", envconfig.AsMap()) handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: level, AddSource: true, @@ -1121,7 +1020,7 @@ func Serve(ln net.Listener) error { return err } - if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { + if !envconfig.NoPrune { // clean up unused layers and manifests if err := PruneLayers(); err != nil { return err @@ -1137,7 +1036,9 @@ func Serve(ln net.Listener) error { } } - s := &Server{addr: ln.Addr()} + ctx, done := context.WithCancel(context.Background()) + sched := InitScheduler(ctx) + s := &Server{addr: ln.Addr(), sched: sched} r := s.GenerateRoutes() slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version)) @@ -1150,7 +1051,9 @@ 
func Serve(ln net.Listener) error { signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) go func() { <-signals - unload() + srvr.Close() + done() + sched.unloadAllRunners() gpu.Cleanup() os.Exit(0) }() @@ -1158,12 +1061,12 @@ func Serve(ln net.Listener) error { if err := llm.Init(); err != nil { return fmt.Errorf("unable to initialize llm library %w", err) } - if runtime.GOOS == "linux" { // TODO - windows too - // check compatibility to log warnings - if _, err := gpu.CheckVRAM(); err != nil { - slog.Info(err.Error()) - } - } + + s.sched.Run(ctx) + + // At startup we retrieve GPU information so we can get log messages before loading a model + // This will log warnings to the log in case we have problems with detected GPUs + _ = gpu.GetGPUInfo() return srvr.Serve(ln) } @@ -1219,9 +1122,9 @@ func streamResponse(c *gin.Context, ch chan any) { } // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model -func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) { +func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) { encode := func(s string) ([]int, error) { - return loaded.llama.Tokenize(ctx, s) + return runner.llama.Tokenize(ctx, s) } prompt, err := ChatPrompt(template, messages, numCtx, encode) @@ -1232,10 +1135,7 @@ func chatPrompt(ctx context.Context, template string, messages []api.Message, nu return prompt, nil } -func ChatHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() - +func (s *Server) ChatHandler(c *gin.Context) { checkpointStart := time.Now() var req api.ChatRequest @@ -1292,8 +1192,12 @@ func ChatHandler(c *gin.Context) { sessionDuration = req.KeepAlive.Duration } - if err := load(c, model, opts, sessionDuration); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + var runner *runnerRef + select { + case runner = <-rCh: + case err = <-eCh: + handleErrorResponse(c, err) return } @@ -1309,7 +1213,7 @@ func ChatHandler(c *gin.Context) { }, req.Messages...) 
} - prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx) + prompt, err := chatPrompt(c.Request.Context(), runner, model.Template, req.Messages, opts.NumCtx) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -1352,8 +1256,6 @@ func ChatHandler(c *gin.Context) { defer close(ch) fn := func(r llm.CompletionResponse) { - // Update model expiration - loaded.expireTimer.Reset(sessionDuration) resp := api.ChatResponse{ Model: req.Model, @@ -1376,7 +1278,7 @@ func ChatHandler(c *gin.Context) { ch <- resp } - if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{ + if err := runner.llama.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Format: req.Format, Images: images, @@ -1416,3 +1318,15 @@ func ChatHandler(c *gin.Context) { streamResponse(c, ch) } + +func handleErrorResponse(c *gin.Context, err error) { + if errors.Is(err, context.Canceled) { + c.JSON(499, gin.H{"error": "request canceled"}) + return + } + if errors.Is(err, ErrMaxQueue) { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) +} diff --git a/server/routes_test.go b/server/routes_test.go index 4f907702..896dc27b 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -55,13 +55,13 @@ func Test_Routes(t *testing.T) { createTestModel := func(t *testing.T, name string) { fname := createTestFile(t, "ollama-model") - modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) - commands, err := parser.Parse(modelfile) + r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) + modelfile, err := model.ParseFile(r) assert.Nil(t, err) fn := func(resp api.ProgressResponse) { t.Logf("Status: %s", resp.Status) } - err = CreateModel(context.TODO(), name, "", "", commands, fn) + err = CreateModel(context.TODO(), name, "", "", modelfile, fn) assert.Nil(t, err) } @@ -124,14 +124,12 @@ func Test_Routes(t *testing.T) { Method: http.MethodPost, Path: "/api/create", Setup: func(t *testing.T, req *http.Request) { - f, err := os.CreateTemp(t.TempDir(), "ollama-model") - assert.Nil(t, err) - defer f.Close() + fname := createTestFile(t, "ollama-model") stream := false createReq := api.CreateRequest{ Name: "t-bone", - Modelfile: fmt.Sprintf("FROM %s", f.Name()), + Modelfile: fmt.Sprintf("FROM %s", fname), Stream: &stream, } jsonData, err := json.Marshal(createReq) @@ -216,28 +214,25 @@ func Test_Routes(t *testing.T) { httpSrv := httptest.NewServer(router) t.Cleanup(httpSrv.Close) - workDir, err := os.MkdirTemp("", "ollama-test") - assert.Nil(t, err) - defer os.RemoveAll(workDir) - os.Setenv("OLLAMA_MODELS", workDir) + t.Setenv("OLLAMA_MODELS", t.TempDir()) for _, tc := range testCases { - t.Logf("Running Test: [%s]", tc.Name) - u := httpSrv.URL + tc.Path - req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) - assert.Nil(t, err) + t.Run(tc.Name, func(t *testing.T) { + u := httpSrv.URL + tc.Path + req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) + assert.Nil(t, err) - if tc.Setup != nil { - tc.Setup(t, req) - } + if tc.Setup != 
nil { + tc.Setup(t, req) + } - resp, err := httpSrv.Client().Do(req) - assert.Nil(t, err) - defer resp.Body.Close() - - if tc.Expected != nil { - tc.Expected(t, resp) - } + resp, err := httpSrv.Client().Do(req) + assert.Nil(t, err) + defer resp.Body.Close() + if tc.Expected != nil { + tc.Expected(t, resp) + } + }) } } diff --git a/server/sched.go b/server/sched.go new file mode 100644 index 00000000..c4a071c1 --- /dev/null +++ b/server/sched.go @@ -0,0 +1,553 @@ +package server + +import ( + "context" + "errors" + "fmt" + "log/slog" + "reflect" + "sort" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/server/envconfig" + "golang.org/x/exp/slices" +) + +type LlmRequest struct { + ctx context.Context //nolint:containedctx + model *Model + opts api.Options + sessionDuration time.Duration + successCh chan *runnerRef + errCh chan error +} + +type Scheduler struct { + pendingReqCh chan *LlmRequest + finishedReqCh chan *LlmRequest + expiredCh chan *runnerRef + unloadedCh chan interface{} + + loaded map[string]*runnerRef + loadedMu sync.Mutex + + loadFn func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) + newServerFn func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) + getGpuFn func() gpu.GpuInfoList +} + +var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded") + +func InitScheduler(ctx context.Context) *Scheduler { + sched := &Scheduler{ + pendingReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), + finishedReqCh: make(chan *LlmRequest, envconfig.MaxQueuedRequests), + expiredCh: make(chan *runnerRef, envconfig.MaxQueuedRequests), + unloadedCh: make(chan interface{}, envconfig.MaxQueuedRequests), + loaded: make(map[string]*runnerRef), + newServerFn: llm.NewLlamaServer, + getGpuFn: gpu.GetGPUInfo, + } + sched.loadFn = sched.load + return sched +} + +// context must be canceled to decrement ref count and release the runner +func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) { + // allocate a large enough kv cache for all parallel requests + opts.NumCtx = opts.NumCtx * envconfig.NumParallel + + req := &LlmRequest{ + ctx: c, + model: model, + opts: opts, + sessionDuration: sessionDuration, + successCh: make(chan *runnerRef), + errCh: make(chan error, 1), + } + + select { + case s.pendingReqCh <- req: + default: + req.errCh <- ErrMaxQueue + } + return req.successCh, req.errCh +} + +// Returns immediately, spawns goroutines for the scheduler which will shut down when ctx is done +func (s *Scheduler) Run(ctx context.Context) { + slog.Debug("starting llm scheduler") + go func() { + s.processPending(ctx) + }() + + go func() { + s.processCompleted(ctx) + }() +} + +func (s *Scheduler) processPending(ctx context.Context) { + for { + select { + case <-ctx.Done(): + slog.Debug("shutting down scheduler pending loop") + return + case pending := <-s.pendingReqCh: + // Block other requests until we get this pending request running + + if pending.ctx.Err() != nil { + slog.Debug("pending request cancelled or timed out, skipping scheduling") + continue + } + + for { + var runnerToExpire *runnerRef + s.loadedMu.Lock() + runner := s.loaded[pending.model.ModelPath] + loadedCount := len(s.loaded) + s.loadedMu.Unlock() + if runner != nil { + if
runner.needsReload(ctx, pending) { + runnerToExpire = runner + } else { + // Runner is usable, return it + pending.useLoadedRunner(runner, s.finishedReqCh) + break + } + } else if envconfig.MaxRunners > 0 && loadedCount >= envconfig.MaxRunners { + slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount) + runnerToExpire = s.findRunnerToUnload() + } else { + // Either no models are loaded or below envconfig.MaxRunners + // Get a refreshed GPU list + gpus := s.getGpuFn() + + // Load model for fitting + ggml, err := llm.LoadModel(pending.model.ModelPath) + if err != nil { + pending.errCh <- err + break + } + + // If we're in CPU-only mode, just limit by envconfig.MaxRunners above + // TODO handle system memory exhaustion + if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 { + slog.Debug("cpu mode with existing models, loading") + s.loadFn(pending, ggml, gpus) + break + } + + // No models loaded. Load the model but prefer the best fit. + if loadedCount == 0 { + slog.Debug("loading first model", "model", pending.model.ModelPath) + g := pickBestFitGPUs(pending, ggml, gpus) + if g != nil { + gpus = g + } + s.loadFn(pending, ggml, gpus) + break + } + + // One or more models are already loaded, so we have to see if the new one fits + // Update free memory from currently loaded models + s.updateFreeSpace(gpus) + gpus = pickBestFitGPUs(pending, ggml, gpus) + if gpus != nil { + slog.Debug("new model fits with existing models, loading") + s.loadFn(pending, ggml, gpus) + break + } + runnerToExpire = s.findRunnerToUnload() + } + + if runnerToExpire == nil { + // Shouldn't happen + slog.Error("runner to expire was nil!") + continue + } + // Trigger an expiration to unload once it's done + runnerToExpire.refMu.Lock() + slog.Debug("resetting model to expire immediately to make room", "model", runnerToExpire.model, "refCount", runnerToExpire.refCount) + if runnerToExpire.expireTimer != nil { + runnerToExpire.expireTimer.Stop() + runnerToExpire.expireTimer = nil + } + runnerToExpire.sessionDuration = 0 + if runnerToExpire.refCount <= 0 { + s.expiredCh <- runnerToExpire + } + runnerToExpire.refMu.Unlock() + // Wait for the unload to happen + // Note: at this point we're queueing up all incoming requests, even if they were for + // a different model that's loaded and not scheduled to be removed.
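For orientation, here is a minimal sketch of the caller's side of this loop, assuming only the GetRunner signature defined above; the helper name runWithModel and the 5-minute keep-alive are illustrative, not part of this change. Canceling the request context is what releases the runner reference and lets the loops above eventually unload the model:

	// Hypothetical caller sketch; the scheduler owns all load/fit/evict decisions.
	func runWithModel(ctx context.Context, s *Scheduler, m *Model, opts api.Options) error {
		ctx, cancel := context.WithCancel(ctx)
		defer cancel() // releases this request's reference on the runner

		rCh, eCh := s.GetRunner(ctx, m, opts, 5*time.Minute)
		select {
		case runner := <-rCh:
			_ = runner // runner.llama is now usable for Completion/Embedding calls
			return nil
		case err := <-eCh:
			return err // e.g. ErrMaxQueue when pendingReqCh is full
		}
	}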
+ slog.Debug("waiting for pending requests to complete and unload to occur", "model", runnerToExpire.model) + select { + case <-ctx.Done(): + slog.Debug("shutting down scheduler pending loop") + return + case <-s.unloadedCh: + slog.Debug("unload completed", "model", runnerToExpire.model) + continue + } + } + case <-s.unloadedCh: + // An unload request when there are no pending request can be ignored + slog.Debug("ignoring unload event with no pending requests") + } + } +} + +func (s *Scheduler) processCompleted(ctx context.Context) { + // Process completed requests, expired timers, and unloading models + for { + select { + case <-ctx.Done(): + slog.Debug("shutting down scheduler completed loop") + return + case finished := <-s.finishedReqCh: + s.loadedMu.Lock() + runner := s.loaded[finished.model.ModelPath] + s.loadedMu.Unlock() + if runner == nil { + slog.Error("finished requeset signal received after model unloaded", "model", finished.model.ModelPath) + continue + } + runner.refMu.Lock() + runner.refCount-- + if runner.refCount <= 0 { + if runner.sessionDuration <= 0 { + slog.Debug("runner with zero duration has gone idle, expiring to unload", "model", runner.model) + if runner.expireTimer != nil { + runner.expireTimer.Stop() + runner.expireTimer = nil + } + s.expiredCh <- runner + } else if runner.expireTimer == nil { + slog.Debug("runner with non-zero duration has gone idle, adding timer", "model", runner.model, "duration", runner.sessionDuration) + runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() { + slog.Debug("timer expired, expiring to unload", "model", runner.model) + runner.refMu.Lock() + defer runner.refMu.Unlock() + if runner.expireTimer != nil { + runner.expireTimer.Stop() + runner.expireTimer = nil + } + s.expiredCh <- runner + }) + } else { + slog.Debug("runner with non-zero duration has gone idle, resetting timer", "model", runner.model, "duration", runner.sessionDuration) + runner.expireTimer.Reset(runner.sessionDuration) + } + } + slog.Debug("after processing request finished event", "model", runner.model, "refCount", runner.refCount) + runner.refMu.Unlock() + case runner := <-s.expiredCh: + slog.Debug("runner expired event received", "model", runner.model) + runner.refMu.Lock() + if runner.refCount > 0 { + // Shouldn't happen, but safeguard to ensure no leaked runners + slog.Debug("expired event with positive ref count, retrying", "model", runner.model, "refCount", runner.refCount) + go func(runner *runnerRef) { + // We can't unload yet, but want to as soon as the current request completes + // So queue up another expired event + time.Sleep(10 * time.Millisecond) + s.expiredCh <- runner + }(runner) + runner.refMu.Unlock() + continue + } + + s.loadedMu.Lock() + slog.Debug("got lock to unload", "model", runner.model) + runner.unload() + delete(s.loaded, runner.model) + s.loadedMu.Unlock() + slog.Debug("runner released", "model", runner.model) + runner.refMu.Unlock() + slog.Debug("sending an unloaded event", "model", runner.model) + s.unloadedCh <- struct{}{} + } + } +} + +// Complete the pending request and send the runner back to the requester +// Wires up a finished event after the request context is completed +// Updates session duration, and resets expiration timer +func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) { + runner.refMu.Lock() + defer runner.refMu.Unlock() + runner.refCount++ + if runner.expireTimer != nil { + runner.expireTimer.Stop() + runner.expireTimer = nil + } + runner.sessionDuration = 
pending.sessionDuration + pending.successCh <- runner + go func() { + <-pending.ctx.Done() + slog.Debug("context for request finished") + finished <- pending + }() +} + +func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) { + llama, err := s.newServerFn(gpus, req.model.ModelPath, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts) + if err != nil { + // some older models are not compatible with newer versions of llama.cpp + // show a generalized compatibility error until there is a better way to + // check for model compatibility + if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { + err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) + } + slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) + req.errCh <- err + return + } + runner := &runnerRef{} + runner.model = req.model.ModelPath + runner.adapters = req.model.AdapterPaths + runner.projectors = req.model.ProjectorPaths + runner.llama = llama + runner.Options = &req.opts + runner.sessionDuration = req.sessionDuration + runner.gpus = gpus + runner.estimatedVRAM = llama.EstimatedVRAM() + runner.loading = true + runner.refCount = 1 + runner.refMu.Lock() + s.loadedMu.Lock() + s.loaded[req.model.ModelPath] = runner + slog.Info("loaded runners", "count", len(s.loaded)) + s.loadedMu.Unlock() + + go func() { + defer runner.refMu.Unlock() + if err = llama.WaitUntilRunning(req.ctx); err != nil { + slog.Error("error loading llama server", "error", err) + runner.refCount-- + req.errCh <- err + slog.Debug("triggering expiration for failed load", "model", runner.model) + s.expiredCh <- runner + return + } + slog.Debug("finished setting up runner", "model", req.model.ModelPath) + runner.loading = false + go func() { + <-req.ctx.Done() + slog.Debug("context for request finished") + s.finishedReqCh <- req + }() + req.successCh <- runner + }() +} + +func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) { + type predKey struct { + Library string + ID string + } + predMap := map[predKey]uint64{} // Sum up the total predicted usage per GPU for all runners + s.loadedMu.Lock() + for _, r := range s.loaded { + r.refMu.Lock() + gpuIDs := make([]string, 0, len(r.gpus)) + if r.llama != nil { + + // TODO this should be broken down by GPU instead of assuming uniform spread + estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus)) + for _, gpu := range r.gpus { + gpuIDs = append(gpuIDs, gpu.ID) + } + for _, gpu := range allGpus { + if slices.Contains(gpuIDs, gpu.ID) { + predMap[predKey{gpu.Library, gpu.ID}] += estimatedVRAMPerGPU + } + } + } else { + slog.Warn("unexpected nil runner reference, memory prediction may be incorrect") + } + r.refMu.Unlock() + } + s.loadedMu.Unlock() + + // Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list + for i := range allGpus { + if p, ok := predMap[predKey{allGpus[i].Library, allGpus[i].ID}]; ok { + slog.Debug("gpu reported", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "available", format.HumanBytes2(allGpus[i].FreeMemory)) + if p > allGpus[i].TotalMemory { + // Shouldn't happen + slog.Warn("predicted usage exceeds VRAM", "gpu", allGpus[i].ID, "totalMemory", allGpus[i].TotalMemory, "predicted", p) + allGpus[i].FreeMemory = 0 + } else if (allGpus[i].TotalMemory - p) < allGpus[i].FreeMemory { // 
predicted free is smaller than reported free, use it + // TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy + // and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded + // after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent. + allGpus[i].FreeMemory = allGpus[i].TotalMemory - p + } + slog.Info("updated VRAM", "gpu", allGpus[i].ID, "library", allGpus[i].Library, "total", format.HumanBytes2(allGpus[i].TotalMemory), "available", format.HumanBytes2(allGpus[i].FreeMemory)) + } + } +} + +type runnerRef struct { + refMu sync.Mutex + // refCond sync.Cond // Signaled on transition from 1 -> 0 refCount + refCount uint // prevent unloading if > 0 + // unloading bool // set to true when we are trying to unload the runner + + llama llm.LlamaServer + loading bool // True only during initial load, then false forever + gpus gpu.GpuInfoList // Recorded at time of provisioning + estimatedVRAM uint64 + + sessionDuration time.Duration + expireTimer *time.Timer + + model string + adapters []string + projectors []string + *api.Options +} + +// The refMu must already be held when calling unload +func (runner *runnerRef) unload() { + if runner.expireTimer != nil { + runner.expireTimer.Stop() + runner.expireTimer = nil + } + if runner.llama != nil { + runner.llama.Close() + } + runner.llama = nil + runner.adapters = nil + runner.projectors = nil + runner.Options = nil + runner.gpus = nil +} + +func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool { + slog.Debug("evaluating already loaded", "model", req.model.ModelPath) + runner.refMu.Lock() + defer runner.refMu.Unlock() + + timeout := 10 * time.Second + if runner.loading { + timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems... + } + + if runner.Options == nil { + return true + } + + // Don't reload runner if num_gpu=-1 was provided + optsExisting := runner.Options.Runner + optsNew := req.opts.Runner + if optsNew.NumGPU < 0 { + optsExisting.NumGPU = -1 + optsNew.NumGPU = -1 + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + if !reflect.DeepEqual(runner.adapters, req.model.AdapterPaths) || // have the adapters changed? + !reflect.DeepEqual(runner.projectors, req.model.ProjectorPaths) || // have the projectors changed? + !reflect.DeepEqual(optsExisting, optsNew) || // have the runner options changed?
+ runner.llama.Ping(ctx) != nil { + return true + } + + return false +} + +type ByDuration []*runnerRef + +func (a ByDuration) Len() int { return len(a) } +func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByDuration) Less(i, j int) bool { + // uint64 to turn negative time (never unload) to largest + return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration) +} + +// TODO - future consideration to pick runners based on size +// type BySize []*runnerRef +// func (a BySize) Len() int { return len(a) } +// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM } + +// pickBestFitGPUs will try to find the optimal placement of the model in the available GPUs where the model fully fits +// If the model can not be fit fully within the available GPU(s) nil is returned +func pickBestFitGPUs(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList) gpu.GpuInfoList { + var estimatedVRAM uint64 + for _, gl := range gpus.ByLibrary() { + var ok bool + sgl := append(make(gpu.GpuInfoList, 0, len(gl)), gl...) + + // TODO - potentially sort by performance capability, existing models loaded, etc. + // Note: at present, this will favor more VRAM over faster GPU speed in mixed setups + sort.Sort(sort.Reverse(gpu.ByFreeMemory(sgl))) + + // First attempt to fit the model into a single GPU + for _, g := range sgl { + if ok, estimatedVRAM = llm.PredictServerFit([]gpu.GpuInfo{g}, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { + slog.Debug("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM)) + return []gpu.GpuInfo{g} + } + } + + // TODO future refinements + // - if multiple Libraries, see if any single GPU in any Library will fit + // - try subsets of GPUs instead of just falling back to 1 or all in a family + + // Now try all the GPUs + if ok, estimatedVRAM = llm.PredictServerFit(gl, ggml, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok { + slog.Debug("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", gl[0].Library, "required", format.HumanBytes2(estimatedVRAM)) + return gl + } + } + return nil +} + +// findRunnerToUnload finds a runner to unload to make room for a new model +func (s *Scheduler) findRunnerToUnload() *runnerRef { + s.loadedMu.Lock() + runnerList := make([]*runnerRef, 0, len(s.loaded)) + for _, r := range s.loaded { + runnerList = append(runnerList, r) + } + s.loadedMu.Unlock() + + // In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload + // e.g., if we have multiple options, will one make room for the request? 
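One detail of ByDuration worth making concrete: the uint64 conversion means a negative sessionDuration (the keep-alive-forever sentinel) wraps around to the largest unsigned value, so those runners sort last and are the last candidates for eviction. A small illustration with made-up values:

	// Illustrative only: negative durations sort to the back under ByDuration.
	runners := []*runnerRef{
		{sessionDuration: -1}, // never unload; uint64(-1) wraps to the max value
		{sessionDuration: 5 * time.Minute},
		{sessionDuration: 30 * time.Second},
	}
	sort.Sort(ByDuration(runners))
	// resulting order: 30s, 5m, then the never-unload runner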
+ sort.Sort(ByDuration(runnerList)) + + // First try to find a runner that's already idle + for _, runner := range runnerList { + runner.refMu.Lock() + rc := runner.refCount + runner.refMu.Unlock() + if rc == 0 { + slog.Debug("found an idle runner to unload") + return runner + } + } + // None appear idle, just wait for the one with the shortest duration + slog.Debug("no idle runners, picking the shortest duration", "count", len(runnerList)) + return runnerList[0] +} + +func (s *Scheduler) unloadAllRunners() { + s.loadedMu.Lock() + defer s.loadedMu.Unlock() + for model, runner := range s.loaded { + if runner.llama != nil { + slog.Debug("shutting down runner", "model", model) + runner.llama.Close() + } + } +} diff --git a/server/sched_test.go b/server/sched_test.go new file mode 100644 index 00000000..7e4faa61 --- /dev/null +++ b/server/sched_test.go @@ -0,0 +1,601 @@ +package server + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "log/slog" + "os" + "testing" + "time" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/app/lifecycle" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/gpu" + "github.com/ollama/ollama/llm" + "github.com/ollama/ollama/server/envconfig" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func init() { + os.Setenv("OLLAMA_DEBUG", "1") + lifecycle.InitLogging() +} + +func TestInitScheduler(t *testing.T) { + ctx, done := context.WithCancel(context.Background()) + defer done() + s := InitScheduler(ctx) + s.loadedMu.Lock() + require.NotNil(t, s.loaded) + s.loadedMu.Unlock() +} + +func TestLoad(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer done() + s := InitScheduler(ctx) + var ggml *llm.GGML // value not used in tests + req := &LlmRequest{ + ctx: ctx, + model: &Model{ModelPath: "foo"}, + opts: api.DefaultOptions(), + successCh: make(chan *runnerRef, 1), + errCh: make(chan error, 1), + sessionDuration: 2, + } + // Fail to load model first + s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + return nil, fmt.Errorf("something failed to load model blah") + } + gpus := gpu.GpuInfoList{} + s.load(req, ggml, gpus) + require.Len(t, req.successCh, 0) + require.Len(t, req.errCh, 1) + s.loadedMu.Lock() + require.Len(t, s.loaded, 0) + s.loadedMu.Unlock() + err := <-req.errCh + require.Contains(t, err.Error(), "this model may be incompatible") + + server := &mockLlm{estimatedVRAM: 10} + s.newServerFn = func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + return server, nil + } + s.load(req, ggml, gpus) + select { + case err := <-req.errCh: + require.NoError(t, err) + case resp := <-req.successCh: + require.Equal(t, uint64(10), resp.estimatedVRAM) + require.Equal(t, uint(1), resp.refCount) + s.loadedMu.Lock() + require.Len(t, s.loaded, 1) + s.loadedMu.Unlock() + } + + req.model.ModelPath = "dummy_model_path" + server.waitResp = fmt.Errorf("wait failure") + s.load(req, ggml, gpus) + select { + case err := <-req.errCh: + require.Contains(t, err.Error(), "wait failure") + case resp := <-req.successCh: + t.Errorf("unexpected success %v", resp) + } + s.loadedMu.Lock() + runner := s.loaded["dummy_model_path"] + s.loadedMu.Unlock() + require.NotNil(t, runner) + require.Equal(t, uint(0), runner.refCount) + time.Sleep(1 * time.Millisecond) + require.Len(t, 
s.expiredCh, 1) +} + +type bundle struct { + ctx context.Context //nolint:containedctx + ctxDone func() + srv *mockLlm + req *LlmRequest + ggml *llm.GGML +} + +func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options) (llm.LlamaServer, error) { + return scenario.srv, nil +} + +func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle { + scenario := &bundle{} + scenario.ctx, scenario.ctxDone = context.WithCancel(ctx) + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), modelName) + assert.Nil(t, err) + defer f.Close() + + gguf := llm.NewGGUFV3(binary.LittleEndian) + err = gguf.Encode(f, llm.KV{ + "general.architecture": "llama", + "general.name": "name", + "llama.context_length": uint32(32), + "llama.embedding_length": uint32(4096), + "llama.block_count": uint32(1), + "llama.attention.head_count": uint32(32), + "llama.attention.head_count_kv": uint32(32), + "tokenizer.ggml.tokens": []string{" "}, + "tokenizer.ggml.scores": []float32{0}, + "tokenizer.ggml.token_type": []int32{0}, + }, []llm.Tensor{ + {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, + }) + assert.Nil(t, err) + + fname := f.Name() + model := &Model{Name: modelName, ModelPath: fname} + scenario.ggml, err = llm.LoadModel(model.ModelPath) + require.NoError(t, err) + + scenario.req = &LlmRequest{ + ctx: scenario.ctx, + model: model, + opts: api.DefaultOptions(), + sessionDuration: 5 * time.Millisecond, + successCh: make(chan *runnerRef, 1), + errCh: make(chan error, 1), + } + scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM} + return scenario +} + +func TestRequests(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + + // Same model, same request + scenario1a := newScenario(t, ctx, "ollama-model-1", 10) + scenario1a.req.sessionDuration = 0 + scenario1b := newScenario(t, ctx, "ollama-model-1", 11) + scenario1b.req.model = scenario1a.req.model + scenario1b.ggml = scenario1a.ggml + scenario1b.req.sessionDuration = 0 + + // simple reload of same model + scenario2a := newScenario(t, ctx, "ollama-model-1", 20) + scenario2a.req.model = scenario1a.req.model + scenario2a.ggml = scenario1a.ggml + + // Multiple loaded models + scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte) + scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte) + scenario3c := newScenario(t, ctx, "ollama-model-4a", 30) + scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed + scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded + + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} + } + s.newServerFn = scenario1a.newServer + slog.Info("scenario1a") + s.pendingReqCh <- scenario1a.req + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case resp := <-scenario1a.req.successCh: + require.Equal(t, resp.llama, scenario1a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario1a.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + + // Same runner as first request due to not needing a reload + s.newServerFn = scenario1b.newServer + slog.Info("scenario1b") + s.pendingReqCh <- scenario1b.req + select { + case resp := <-scenario1b.req.successCh: + require.Equal(t, 
resp.llama, scenario1a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario1b.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + + // Trigger a reload + s.newServerFn = scenario2a.newServer + scenario2a.req.model.AdapterPaths = []string{"new"} + slog.Info("scenario2a") + s.pendingReqCh <- scenario2a.req + // finish first two requests, so model can reload + time.Sleep(1 * time.Millisecond) + scenario1a.ctxDone() + scenario1b.ctxDone() + select { + case resp := <-scenario2a.req.successCh: + require.Equal(t, resp.llama, scenario2a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario2a.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + + envconfig.MaxRunners = 1 + s.newServerFn = scenario3a.newServer + slog.Info("scenario3a") + s.pendingReqCh <- scenario3a.req + // finish prior request, so new model can load + time.Sleep(1 * time.Millisecond) + scenario2a.ctxDone() + select { + case resp := <-scenario3a.req.successCh: + require.Equal(t, resp.llama, scenario3a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario3a.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + s.loadedMu.Lock() + require.Len(t, s.loaded, 1) + s.loadedMu.Unlock() + + envconfig.MaxRunners = 0 + s.newServerFn = scenario3b.newServer + slog.Info("scenario3b") + s.pendingReqCh <- scenario3b.req + select { + case resp := <-scenario3b.req.successCh: + require.Equal(t, resp.llama, scenario3b.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario3b.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + s.loadedMu.Lock() + require.Len(t, s.loaded, 2) + s.loadedMu.Unlock() + + // This is a CPU load with NumGPU = 0 so it should load + s.newServerFn = scenario3c.newServer + slog.Info("scenario3c") + s.pendingReqCh <- scenario3c.req + select { + case resp := <-scenario3c.req.successCh: + require.Equal(t, resp.llama, scenario3c.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario3c.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + s.loadedMu.Lock() + require.Len(t, s.loaded, 3) + s.loadedMu.Unlock() + + // Try to load a model that won't fit + s.newServerFn = scenario3d.newServer + slog.Info("scenario3d") + s.loadedMu.Lock() + require.Len(t, s.loaded, 3) + s.loadedMu.Unlock() + scenario3a.ctxDone() // Won't help since this one isn't big enough to make room + time.Sleep(2 * time.Millisecond) + s.pendingReqCh <- scenario3d.req + // finish prior request, so new model can load + time.Sleep(6 * time.Millisecond) + s.loadedMu.Lock() + require.Len(t, s.loaded, 2) + s.loadedMu.Unlock() + scenario3b.ctxDone() + select { + case resp := <-scenario3d.req.successCh: + require.Equal(t, resp.llama, scenario3d.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, scenario3d.req.errCh, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + s.loadedMu.Lock() + require.Len(t, s.loaded, 2) + s.loadedMu.Unlock() +} + +func TestGetRunner(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer done() + + // Same model, same request + scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + scenario1a.req.sessionDuration = 0 + scenario1b := newScenario(t, ctx, "ollama-model-1b", 10) + scenario1b.req.sessionDuration = 0 + scenario1c := newScenario(t, ctx, "ollama-model-1c", 10) + scenario1c.req.sessionDuration = 0 + envconfig.MaxQueuedRequests = 1 + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + 
g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} + } + s.newServerFn = scenario1a.newServer + slog.Info("scenario1a") + successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) + require.Len(t, s.pendingReqCh, 1) + slog.Info("scenario1b") + successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) + require.Len(t, s.pendingReqCh, 1) + require.Len(t, successCh1b, 0) + require.Len(t, errCh1b, 1) + err := <-errCh1b + require.Contains(t, err.Error(), "server busy") + s.Run(ctx) + select { + case resp := <-successCh1a: + require.Equal(t, resp.llama, scenario1a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, errCh1a, 0) + case <-ctx.Done(): + t.Errorf("timeout") + } + scenario1a.ctxDone() + s.loadedMu.Lock() + require.Len(t, s.loaded, 1) + s.loadedMu.Unlock() + + scenario1c.req.model.ModelPath = "bad path" + slog.Info("scenario1c") + successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) + // Starts in pending channel, then should be quickly processed to return an error + time.Sleep(5 * time.Millisecond) + require.Len(t, successCh1c, 0) + s.loadedMu.Lock() + require.Len(t, s.loaded, 0) + s.loadedMu.Unlock() + require.Len(t, errCh1c, 1) + err = <-errCh1c + require.Contains(t, err.Error(), "bad path") + scenario1b.ctxDone() +} + +// TODO - add one scenario that triggers the bogus finished event with positive ref count +func TestPrematureExpired(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer done() + + // Same model, same request + scenario1a := newScenario(t, ctx, "ollama-model-1a", 10) + s := InitScheduler(ctx) + s.getGpuFn = func() gpu.GpuInfoList { + g := gpu.GpuInfo{Library: "metal"} + g.TotalMemory = 24 * format.GigaByte + g.FreeMemory = 12 * format.GigaByte + return []gpu.GpuInfo{g} + } + s.newServerFn = scenario1a.newServer + successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration) + require.Len(t, s.pendingReqCh, 1) + s.Run(ctx) + select { + case resp := <-successCh1a: + require.Equal(t, resp.llama, scenario1a.srv) + require.Len(t, s.pendingReqCh, 0) + require.Len(t, errCh1a, 0) + s.loadedMu.Lock() + require.Len(t, s.loaded, 1) + s.loadedMu.Unlock() + slog.Info("sending premature expired event now") + s.expiredCh <- resp // Shouldn't happen in real life, but make sure it's safe + case <-ctx.Done(): + t.Errorf("timeout") + } + time.Sleep(scenario1a.req.sessionDuration) + scenario1a.ctxDone() + time.Sleep(20 * time.Millisecond) + require.LessOrEqual(t, len(s.finishedReqCh), 1) + time.Sleep(10 * time.Millisecond) + require.Len(t, s.finishedReqCh, 0) + s.loadedMu.Lock() + require.Len(t, s.loaded, 0) + s.loadedMu.Unlock() + + // also shouldn't happen in real life + s.finishedReqCh <- scenario1a.req + time.Sleep(5 * time.Millisecond) +} + +func TestUseLoadedRunner(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + req := &LlmRequest{ + ctx: ctx, + opts: api.DefaultOptions(), + successCh: make(chan *runnerRef, 1), + sessionDuration: 2, + } + finished := make(chan *LlmRequest) + llm1 := &mockLlm{} + r1 := &runnerRef{llama: llm1, sessionDuration: 1} + req.useLoadedRunner(r1, finished) + require.Equal(t, uint(1), r1.refCount) + require.Equal(t, time.Duration(2), r1.sessionDuration) + select { + case 
success := <-req.successCh: + require.Equal(t, r1, success) + case <-ctx.Done(): + t.Errorf("timeout") + } + done() + fin := <-finished + require.Equal(t, req, fin) +} + +func TestUpdateFreeSpace(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer done() + gpus := gpu.GpuInfoList{ + { + Library: "a", + ID: "1", + }, + { + Library: "a", + ID: "2", + }, + } + gpus[0].TotalMemory = 1000 + gpus[0].FreeMemory = 900 + gpus[1].TotalMemory = 2000 + gpus[1].FreeMemory = 1900 + llm1 := &mockLlm{estimatedVRAM: 100} + llm2 := &mockLlm{estimatedVRAM: 200} + r1 := &runnerRef{llama: llm1, gpus: gpus} + r2 := &runnerRef{llama: llm2, gpus: gpus} + + s := InitScheduler(ctx) + s.loadedMu.Lock() + s.loaded["a"] = r1 + s.loaded["b"] = r2 + s.loadedMu.Unlock() + + s.updateFreeSpace(gpus) + require.Equal(t, uint64(850), gpus[0].FreeMemory) + require.Equal(t, uint64(1850), gpus[1].FreeMemory) +} + +func TestFindRunnerToUnload(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer done() + + r1 := &runnerRef{refCount: 1, sessionDuration: 1} + r2 := &runnerRef{sessionDuration: 2} + + s := InitScheduler(ctx) + s.loadedMu.Lock() + s.loaded["a"] = r1 + s.loaded["b"] = r2 + s.loadedMu.Unlock() + + resp := s.findRunnerToUnload() + require.Equal(t, r2, resp) + r2.refCount = 1 + resp = s.findRunnerToUnload() + require.Equal(t, r1, resp) + +} + +func TestNeedsReload(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer done() + + llm := &mockLlm{} + do := api.DefaultOptions() + runner := &runnerRef{ + adapters: []string{"adapter1"}, + projectors: []string{"projector1"}, + Options: &do, + llama: llm, + } + req := &LlmRequest{ + model: &Model{ + AdapterPaths: []string{"adapter2"}, + ProjectorPaths: []string{"projector2"}, + }, + opts: api.DefaultOptions(), + } + resp := runner.needsReload(ctx, req) + require.True(t, resp) + req.model.AdapterPaths = runner.adapters + resp = runner.needsReload(ctx, req) + require.True(t, resp) + req.model.ProjectorPaths = runner.projectors + runner.loading = true + req.opts.NumBatch = 1234 + resp = runner.needsReload(ctx, req) + require.True(t, resp) + req.opts.NumBatch = runner.Options.NumBatch + llm.pingResp = fmt.Errorf("foo") + resp = runner.needsReload(ctx, req) + require.True(t, resp) + llm.pingResp = nil + resp = runner.needsReload(ctx, req) + require.False(t, resp) + req.opts.NumGPU = 99 + resp = runner.needsReload(ctx, req) + require.True(t, resp) + req.opts.NumGPU = -1 + resp = runner.needsReload(ctx, req) + require.False(t, resp) +} + +func TestUnloadAllRunners(t *testing.T) { + ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer done() + + llm1 := &mockLlm{} + llm2 := &mockLlm{} + s := InitScheduler(ctx) + s.unloadAllRunners() + + r1 := &runnerRef{llama: llm1} + r2 := &runnerRef{llama: llm2} + + s.loadedMu.Lock() + s.loaded["a"] = r1 + s.loaded["b"] = r2 + s.loadedMu.Unlock() + s.unloadAllRunners() + + require.True(t, llm1.closeCalled) + require.True(t, llm2.closeCalled) +} + +func TestUnload(t *testing.T) { + llm1 := &mockLlm{} + r1 := &runnerRef{llama: llm1} + r2 := &runnerRef{adapters: []string{"A"}} + r1.unload() + require.True(t, llm1.closeCalled) + r2.unload() + require.Nil(t, r2.adapters) +} + +type mockLlm struct { + pingResp error + waitResp error + completionResp error + embeddingResp []float64 + embeddingRespErr error + tokenizeResp []int + tokenizeRespErr error + detokenizeResp string + 
detokenizeRespErr error + closeResp error + closeCalled bool + estimatedVRAM uint64 +} + +func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp } +func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp } +func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { + return s.completionResp +} +func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) { + return s.embeddingResp, s.embeddingRespErr +} +func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { + return s.tokenizeResp, s.tokenizeRespErr +} +func (s *mockLlm) Detokenize(ctx context.Context, tokens []int) (string, error) { + return s.detokenizeResp, s.detokenizeRespErr +} +func (s *mockLlm) Close() error { + s.closeCalled = true + return s.closeResp +} +func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM } diff --git a/types/errtypes/errtypes.go b/types/errtypes/errtypes.go new file mode 100644 index 00000000..e3a18d0b --- /dev/null +++ b/types/errtypes/errtypes.go @@ -0,0 +1,18 @@ +// Package errtypes contains custom error types +package errtypes + +import ( + "fmt" + "strings" +) + +const UnknownOllamaKeyErrMsg = "unknown ollama key" + +// TODO: This should have a structured response from the API +type UnknownOllamaKey struct { + Key string +} + +func (e *UnknownOllamaKey) Error() string { + return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key)) +}
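A rough usage sketch of the new error type (illustrative only, not part of the patch; the key value is a placeholder):

package main

import (
	"errors"
	"fmt"

	"github.com/ollama/ollama/types/errtypes"
)

func main() {
	// err would normally come back from a registry push; constructed directly here.
	var err error = fmt.Errorf("push failed: %w", &errtypes.UnknownOllamaKey{Key: "ssh-ed25519 AAAA..."})

	// Callers can match the wrapped error with errors.As and recover the key.
	var uke *errtypes.UnknownOllamaKey
	if errors.As(err, &uke) {
		fmt.Println("server did not recognize key:", uke.Key)
	}
}

diff --git a/types/model/digest.go b/types/model/digest.go deleted file mode 100644 index d5a7a155..00000000 --- a/types/model/digest.go +++ /dev/null @@ -1,79 +0,0 @@ -package model - -import ( - "log/slog" - "strings" - "unicode" -) - -// Digest represents a digest of a model Manifest. It is a comparable value -// type and is immutable. -// -// The zero Digest is not a valid digest. -type Digest struct { - s string -} - -// Type returns the digest type of the digest. -// -// Example: -// -// ParseDigest("sha256-1234").Type() // returns "sha256" -func (d Digest) Type() string { - typ, _, _ := strings.Cut(d.s, "-") - return typ -} - -// String returns the digest in the form of "-", or the -// empty string if the digest is invalid. -func (d Digest) String() string { return d.s } - -// IsValid returns true if the digest is valid (not zero). -// -// A valid digest may be created only by ParseDigest, or -// ParseName(name).Digest(). -func (d Digest) IsValid() bool { return d.s != "" } - -// LogValue implements slog.Value. -func (d Digest) LogValue() slog.Value { - return slog.StringValue(d.String()) } - -var ( - _ slog.LogValuer = Digest{} ) - -// ParseDigest parses a string in the form of "-" into a -// Digest. 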
-func ParseDigest(s string) Digest { - typ, digest, ok := strings.Cut(s, "-") - if ok && isValidDigestType(typ) && isValidHex(digest) { - return Digest{s: s} - } - return Digest{} -} - -func isValidDigestType(s string) bool { - if len(s) == 0 { - return false - } - for _, r := range s { - if !unicode.IsLower(r) && !unicode.IsDigit(r) { - return false - } - } - return true -} - -func isValidHex(s string) bool { - if len(s) == 0 { - return false - } - for i := range s { - c := s[i] - if c < '0' || c > '9' && c < 'a' || c > 'f' { - return false - } - } - return true -} diff --git a/types/model/digest_test.go b/types/model/digest_test.go deleted file mode 100644 index 5096a28a..00000000 --- a/types/model/digest_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package model - -import "testing" - -var testDigests = map[string]Digest{ - "": {}, - "sha256-1234": {s: "sha256-1234"}, - "sha256-5678": {s: "sha256-5678"}, - "blake2-9abc": {s: "blake2-9abc"}, - "-1234": {}, - "sha256-": {}, - "sha256-1234-5678": {}, - "sha256-P": {}, // invalid hex - "sha256-1234P": {}, - "---": {}, -} - -func TestDigestParse(t *testing.T) { - // Test cases. - for s, want := range testDigests { - got := ParseDigest(s) - t.Logf("ParseDigest(%q) = %#v", s, got) - if got != want { - t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want) - } - } -} - -func TestDigestString(t *testing.T) { - // Test cases. - for s, d := range testDigests { - want := s - if !d.IsValid() { - want = "" - } - got := d.String() - if got != want { - t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want) - } - - got = ParseDigest(s).String() - if got != want { - t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want) - } - } -} diff --git a/types/model/file.go b/types/model/file.go new file mode 100644 index 00000000..ee398309 --- /dev/null +++ b/types/model/file.go @@ -0,0 +1,299 @@ +package model + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "strconv" + "strings" +) + +type File struct { + Commands []Command +} + +func (f File) String() string { + var sb strings.Builder + for _, cmd := range f.Commands { + fmt.Fprintln(&sb, cmd.String()) + } + + return sb.String() +} + +type Command struct { + Name string + Args string +} + +func (c Command) String() string { + var sb strings.Builder + switch c.Name { + case "model": + fmt.Fprintf(&sb, "FROM %s", c.Args) + case "license", "template", "system", "adapter": + fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) + case "message": + role, message, _ := strings.Cut(c.Args, ": ") + fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message)) + default: + fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args)) + } + + return sb.String() +} + +type state int + +const ( + stateNil state = iota + stateName + stateValue + stateParameter + stateMessage + stateComment +) + +var ( + errMissingFrom = errors.New("no FROM line") + errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"") +) + +func ParseFile(r io.Reader) (*File, error) { + var cmd Command + var curr state + var b bytes.Buffer + var role string + + var f File + + br := bufio.NewReader(r) + for { + r, _, err := br.ReadRune() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return nil, err + } + + next, r, err := parseRuneForState(r, curr) + if errors.Is(err, io.ErrUnexpectedEOF) { + return 
nil, fmt.Errorf("%w: %s", err, b.String()) + } else if err != nil { + return nil, err + } + + // process the state transition; some transitions need to be intercepted and redirected + if next != curr { + switch curr { + case stateName: + if !isValidCommand(b.String()) { + return nil, errInvalidCommand + } + + // next state sometimes depends on the current buffer value + switch s := strings.ToLower(b.String()); s { + case "from": + cmd.Name = "model" + case "parameter": + // transition to stateParameter which sets command name + next = stateParameter + case "message": + // transition to stateMessage which validates the message role + next = stateMessage + fallthrough + default: + cmd.Name = s + } + case stateParameter: + cmd.Name = b.String() + case stateMessage: + if !isValidMessageRole(b.String()) { + return nil, errInvalidMessageRole + } + + role = b.String() + case stateComment, stateNil: + // pass + case stateValue: + s, ok := unquote(b.String()) + if !ok || isSpace(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err + } + + continue + } + + if role != "" { + s = role + ": " + s + role = "" + } + + cmd.Args = s + f.Commands = append(f.Commands, cmd) + } + + b.Reset() + curr = next + } + + if strconv.IsPrint(r) { + if _, err := b.WriteRune(r); err != nil { + return nil, err + } + } + } + + // flush the buffer + switch curr { + case stateComment, stateNil: + // pass; nothing to flush + case stateValue: + s, ok := unquote(b.String()) + if !ok { + return nil, io.ErrUnexpectedEOF + } + + if role != "" { + s = role + ": " + s + } + + cmd.Args = s + f.Commands = append(f.Commands, cmd) + default: + return nil, io.ErrUnexpectedEOF + } + + for _, cmd := range f.Commands { + if cmd.Name == "model" { + return &f, nil + } + } + + return nil, errMissingFrom +}
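A minimal sketch of the intended call pattern for ParseFile, assuming the grammar described above (illustrative, not part of the patch):

package main

import (
	"fmt"
	"strings"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// Parse a small Modelfile and walk the resulting commands.
	f, err := model.ParseFile(strings.NewReader("FROM llama3\nPARAMETER temperature 0.7\n"))
	if err != nil {
		panic(err)
	}
	for _, cmd := range f.Commands {
		// prints "model => llama3", then "temperature => 0.7"
		fmt.Printf("%s => %s\n", cmd.Name, cmd.Args)
	}
}

+ +func parseRuneForState(r rune, cs state) (state, rune, error) { + switch cs { + case stateNil: + switch { + case r == '#': + return stateComment, 0, nil + case isSpace(r), isNewline(r): + return stateNil, 0, nil + default: + return stateName, r, nil + } + case stateName: + switch { + case isAlpha(r): + return stateName, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, errInvalidCommand + } + case stateValue: + switch { + case isNewline(r): + return stateNil, r, nil + case isSpace(r): + return stateNil, r, nil + default: + return stateValue, r, nil + } + case stateParameter: + switch { + case isAlpha(r), isNumber(r), r == '_': + return stateParameter, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateMessage: + switch { + case isAlpha(r): + return stateMessage, r, nil + case isSpace(r): + return stateValue, 0, nil + default: + return stateNil, 0, io.ErrUnexpectedEOF + } + case stateComment: + switch { + case isNewline(r): + return stateNil, 0, nil + default: + return stateComment, 0, nil + } + default: + return stateNil, 0, errors.New("") + } +} + +func quote(s string) string { + if strings.Contains(s, "\n") || strings.HasPrefix(s, " ") || strings.HasSuffix(s, " ") { + if strings.Contains(s, "\"") { + return `"""` + s + `"""` + } + + return `"` + s + `"` + } + + return s +} + +func unquote(s string) (string, bool) { + // TODO: single quotes + if len(s) >= 3 && s[:3] == `"""` { + if len(s) >= 6 && s[len(s)-3:] == `"""` { + return s[3 : len(s)-3], true + } + + return "", false + } + + if len(s) >= 1 && s[0] == '"' { + if len(s) >= 2 && s[len(s)-1] == '"' { + return s[1 : len(s)-1], true + } + + return "", false + } + + 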
return s, true +} + +func isAlpha(r rune) bool { + return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' +} + +func isNumber(r rune) bool { + return r >= '0' && r <= '9' +} + +func isSpace(r rune) bool { + return r == ' ' || r == '\t' +} + +func isNewline(r rune) bool { + return r == '\r' || r == '\n' +} + +func isValidMessageRole(role string) bool { + return role == "system" || role == "user" || role == "assistant" +} + +func isValidCommand(cmd string) bool { + switch strings.ToLower(cmd) { + case "from", "license", "template", "system", "adapter", "parameter", "message": + return true + default: + return false + } +} diff --git a/types/model/file_test.go b/types/model/file_test.go new file mode 100644 index 00000000..8e71760c --- /dev/null +++ b/types/model/file_test.go @@ -0,0 +1,511 @@ +package model + +import ( + "bytes" + "fmt" + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseFileFile(t *testing.T) { + input := ` +FROM model1 +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +` + + reader := strings.NewReader(input) + + modelfile, err := ParseFile(reader) + assert.NoError(t, err) + + expectedCommands := []Command{ + {Name: "model", Args: "model1"}, + {Name: "adapter", Args: "adapter1"}, + {Name: "license", Args: "MIT"}, + {Name: "param1", Args: "value1"}, + {Name: "param2", Args: "value2"}, + {Name: "template", Args: "template1"}, + } + + assert.Equal(t, expectedCommands, modelfile.Commands) +} + +func TestParseFileFrom(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + "FROM foo", + []Command{{Name: "model", Args: "foo"}}, + nil, + }, + { + "FROM /path/to/model", + []Command{{Name: "model", Args: "/path/to/model"}}, + nil, + }, + { + "FROM /path/to/model/fp16.bin", + []Command{{Name: "model", Args: "/path/to/model/fp16.bin"}}, + nil, + }, + { + "FROM llama3:latest", + []Command{{Name: "model", Args: "llama3:latest"}}, + nil, + }, + { + "FROM llama3:7b-instruct-q4_K_M", + []Command{{Name: "model", Args: "llama3:7b-instruct-q4_K_M"}}, + nil, + }, + { + "", nil, errMissingFrom, + }, + { + "PARAMETER param1 value1", + nil, + errMissingFrom, + }, + { + "PARAMETER param1 value1\nFROM foo", + []Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}}, + nil, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + modelfile, err := ParseFile(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + if modelfile != nil { + assert.Equal(t, c.expected, modelfile.Commands) + } + }) + } +} + +func TestParseFileParametersMissingValue(t *testing.T) { + input := ` +FROM foo +PARAMETER param1 +` + + reader := strings.NewReader(input) + + _, err := ParseFile(reader) + assert.ErrorIs(t, err, io.ErrUnexpectedEOF) +} + +func TestParseFileBadCommand(t *testing.T) { + input := ` +FROM foo +BADCOMMAND param1 value1 +` + _, err := ParseFile(strings.NewReader(input)) + assert.ErrorIs(t, err, errInvalidCommand) + +} + +func TestParseFileMessages(t *testing.T) { + var cases = []struct { + input string + expected []Command + err error + }{ + { + ` +FROM foo +MESSAGE system You are a file parser. Always parse things. +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a file parser. Always parse things."}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE system You are a file parser. 
Always parse things.`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a file parser. Always parse things."}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE system You are a file parser. Always parse things. +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a file parser. Always parse things."}, + {Name: "message", Args: "user: Hey there!"}, + {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE system """ +You are a multiline file parser. Always parse things. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"}, + }, + nil, + }, + { + ` +FROM foo +MESSAGE badguy I'm a bad guy! +`, + nil, + errInvalidMessageRole, + }, + { + ` +FROM foo +MESSAGE system +`, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +MESSAGE system`, + nil, + io.ErrUnexpectedEOF, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + modelfile, err := ParseFile(strings.NewReader(c.input)) + assert.ErrorIs(t, err, c.err) + if modelfile != nil { + assert.Equal(t, c.expected, modelfile.Commands) + } + }) + } +} + +func TestParseFileQuoted(t *testing.T) { + var cases = []struct { + multiline string + expected []Command + err error + }{ + { + ` +FROM foo +SYSTEM """ +This is a +multiline system. +""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "\nThis is a\nmultiline system.\n"}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """ +This is a +multiline system.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "\nThis is a\nmultiline system."}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """This is a +multiline system.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "This is a\nmultiline system."}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """This is a multiline system.""" + `, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "This is a multiline system."}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """This is a multiline system."" + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +SYSTEM " + `, + nil, + io.ErrUnexpectedEOF, + }, + { + ` +FROM foo +SYSTEM """ +This is a multiline system with "quotes". 
+""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "\nThis is a multiline system with \"quotes\".\n"}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM "" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: ""}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM "'" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: "'"}, + }, + nil, + }, + { + ` +FROM foo +SYSTEM """''"'""'""'"'''''""'""'""" +`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "system", Args: `''"'""'""'"'''''""'""'`}, + }, + nil, + }, + { + ` +FROM foo +TEMPLATE """ +{{ .Prompt }} +"""`, + []Command{ + {Name: "model", Args: "foo"}, + {Name: "template", Args: "\n{{ .Prompt }}\n"}, + }, + nil, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + modelfile, err := ParseFile(strings.NewReader(c.multiline)) + assert.ErrorIs(t, err, c.err) + if modelfile != nil { + assert.Equal(t, c.expected, modelfile.Commands) + } + }) + } +} + +func TestParseFileParameters(t *testing.T) { + var cases = map[string]struct { + name, value string + }{ + "numa true": {"numa", "true"}, + "num_ctx 1": {"num_ctx", "1"}, + "num_batch 1": {"num_batch", "1"}, + "num_gqa 1": {"num_gqa", "1"}, + "num_gpu 1": {"num_gpu", "1"}, + "main_gpu 1": {"main_gpu", "1"}, + "low_vram true": {"low_vram", "true"}, + "f16_kv true": {"f16_kv", "true"}, + "logits_all true": {"logits_all", "true"}, + "vocab_only true": {"vocab_only", "true"}, + "use_mmap true": {"use_mmap", "true"}, + "use_mlock true": {"use_mlock", "true"}, + "num_thread 1": {"num_thread", "1"}, + "num_keep 1": {"num_keep", "1"}, + "seed 1": {"seed", "1"}, + "num_predict 1": {"num_predict", "1"}, + "top_k 1": {"top_k", "1"}, + "top_p 1.0": {"top_p", "1.0"}, + "tfs_z 1.0": {"tfs_z", "1.0"}, + "typical_p 1.0": {"typical_p", "1.0"}, + "repeat_last_n 1": {"repeat_last_n", "1"}, + "temperature 1.0": {"temperature", "1.0"}, + "repeat_penalty 1.0": {"repeat_penalty", "1.0"}, + "presence_penalty 1.0": {"presence_penalty", "1.0"}, + "frequency_penalty 1.0": {"frequency_penalty", "1.0"}, + "mirostat 1": {"mirostat", "1"}, + "mirostat_tau 1.0": {"mirostat_tau", "1.0"}, + "mirostat_eta 1.0": {"mirostat_eta", "1.0"}, + "penalize_newline true": {"penalize_newline", "true"}, + "stop ### User:": {"stop", "### User:"}, + "stop ### User: ": {"stop", "### User: "}, + "stop \"### User:\"": {"stop", "### User:"}, + "stop \"### User: \"": {"stop", "### User: "}, + "stop \"\"\"### User:\"\"\"": {"stop", "### User:"}, + "stop \"\"\"### User:\n\"\"\"": {"stop", "### User:\n"}, + "stop <|endoftext|>": {"stop", "<|endoftext|>"}, + "stop <|eot_id|>": {"stop", "<|eot_id|>"}, + "stop ": {"stop", ""}, + } + + for k, v := range cases { + t.Run(k, func(t *testing.T) { + var b bytes.Buffer + fmt.Fprintln(&b, "FROM foo") + fmt.Fprintln(&b, "PARAMETER", k) + modelfile, err := ParseFile(&b) + assert.NoError(t, err) + + assert.Equal(t, []Command{ + {Name: "model", Args: "foo"}, + {Name: v.name, Args: v.value}, + }, modelfile.Commands) + }) + } +} + +func TestParseFileComments(t *testing.T) { + var cases = []struct { + input string + expected []Command + }{ + { + ` +# comment +FROM foo + `, + []Command{ + {Name: "model", Args: "foo"}, + }, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + modelfile, err := ParseFile(strings.NewReader(c.input)) + assert.NoError(t, err) + assert.Equal(t, c.expected, 
modelfile.Commands) + }) + } +} + +func TestParseFileFormatParseFile(t *testing.T) { + var cases = []string{ + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system You are a file parser. Always parse things. +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE MIT +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responded with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +ADAPTER adapter1 +LICENSE """ +Very long and boring legal text. +Blah blah blah. +"Oh look, a quote!" +""" + +PARAMETER param1 value1 +PARAMETER param2 value2 +TEMPLATE template1 +MESSAGE system """ +You are a store greeter. Always responded with "Hello!". +""" +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +`, + ` +FROM foo +SYSTEM "" +`, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + modelfile, err := ParseFile(strings.NewReader(c)) + assert.NoError(t, err) + + modelfile2, err := ParseFile(strings.NewReader(modelfile.String())) + assert.NoError(t, err) + + assert.Equal(t, modelfile, modelfile2) + }) + } + +} diff --git a/types/model/name.go b/types/model/name.go index 9c56c49a..b79374c3 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -1,693 +1,425 @@ +// Package model contains types and utilities for parsing, validating, and +// working with model names and digests. package model import ( "cmp" + "encoding/hex" "errors" "fmt" - "hash/maphash" - "io" "log/slog" "path/filepath" - "slices" "strings" - "sync" - - "github.com/ollama/ollama/types/structs" ) // Errors var ( - // ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not - // used by this package, but are exported so that other packages can - // use them, instead of defining their own errors for them. - ErrInvalidName = errors.New("invalid model name") - ErrIncompleteName = errors.New("incomplete model name") - ErrInvalidDigest = errors.New("invalid digest") + // ErrUnqualifiedName represents an error where a name is not fully + // qualified. It is not used directly in this package, but is here + // to avoid other packages inventing their own error type. + // Additionally, it can be conveniently used via [Unqualified]. + ErrUnqualifiedName = errors.New("unqualified name") ) -// Defaults -const ( - // MaskDefault is the default mask used by [Name.DisplayShortest]. - MaskDefault = "registry.ollama.ai/library/?:latest" - - // MaskNothing is a mask that masks nothing. - MaskNothing = "?/?/?:?" - - // DefaultFill is the default fill used by [ParseName]. - FillDefault = "registry.ollama.ai/library/?:latest+Q4_0" - - // FillNothing is a fill that fills nothing. - FillNothing = "?/?/?:?+?" -) - -const MaxNamePartLen = 128 - -type PartKind int - -// Levels of concreteness -const ( - // Each value aligns with its index in the Name.parts array. - - PartHost PartKind = iota - PartNamespace - PartModel - PartTag - PartBuild - PartDigest - - // NumParts is the number of parts in a Name. In this list, it must - // follow the final part. 
- NumParts - - PartExtraneous = -1 -) - -var kindNames = map[PartKind]string{ - PartHost: "Host", - PartNamespace: "Namespace", - PartModel: "Name", - PartTag: "Tag", - PartBuild: "Build", - PartDigest: "Digest", +// Unqualified is a helper function that returns an error with +// ErrUnqualifiedName as the cause and the name as the message. +func Unqualified(n Name) error { + return fmt.Errorf("%w: %s", ErrUnqualifiedName, n) } -func (k PartKind) String() string { - return cmp.Or(kindNames[k], "Unknown") +// MissingPart is used to indicate any part of a name that was "promised" by +// the presence of a separator, but is missing. +// +// The value was chosen because it is deemed unlikely to be set by a user, +// is not a valid part name when checked by [Name.IsValid], and is easy to +// spot in logs. +const MissingPart = "!MISSING!" + +const ( + defaultHost = "registry.ollama.ai" + defaultNamespace = "library" + defaultTag = "latest" +) + +// DefaultName returns a name with the default values for the host, namespace, +// and tag parts. The model and digest parts are empty. +// +// - The default host is ("registry.ollama.ai") +// - The default namespace is ("library") +// - The default tag is ("latest") +func DefaultName() Name { + return Name{ + Host: defaultHost, + Namespace: defaultNamespace, + Tag: defaultTag, + } } -// Name is an opaque reference to a model. It holds the parts of a model -// with the case preserved, but is not directly comparable with other Names -// since model names can be represented with different casing depending on -// the use case. For instance, "Mistral" and "mistral" are the same model -// but each version may have come from different sources (e.g. copied from a -// Web page, or from a file path). +type partKind int + +const ( + kindHost partKind = iota + kindNamespace + kindModel + kindTag + kindDigest +) + +func (k partKind) String() string { + switch k { + case kindHost: + return "host" + case kindNamespace: + return "namespace" + case kindModel: + return "model" + case kindTag: + return "tag" + case kindDigest: + return "digest" + default: + return "unknown" + } +} + +// Name is a structured representation of a model name string, as defined by +// [ParseNameBare]. // -// Valid Names can ONLY be constructed by calling [ParseName]. -// -// A Name is valid if and only if is have a valid Model part. The other parts -// are optional. -// -// A Name is considered "complete" if it has all parts present. To check if a -// Name is complete, use [Name.IsComplete]. -// -// To compare two names in a case-insensitive manner, use [Name.EqualFold]. -// -// The parts of a Name are: -// -// - Host: the domain of the model (optional) -// - Namespace: the namespace of the model (optional) -// - Model: the name of the model (required) -// - Tag: the tag of the model (optional) -// - Build: the build of the model; usually the quantization or "file type" (optional) -// -// The parts can be obtained in their original form by calling [Name.Parts]. -// -// To check if a Name has at minimum a valid model part, use [Name.IsValid]. +// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name +// is valid. type Name struct { - _ structs.Incomparable - parts [NumParts]string // host, namespace, model, tag, build, digest - - // TODO(bmizerany): track offsets and hold s (raw string) here? We - // could pack the offsets all into a single uint64 since the first - // parts take less bits since their max offset is less than the max - // offset of the next part. 
This would save a ton of bytes per Name - and mean zero allocations for String. + Host string + Namespace string + Model string + Tag string + RawDigest string } -// ParseName parses s into a Name, and returns the result of filling it with -// defaults. The input string must be a valid string -// representation of a model name in the form: +// ParseName parses and assembles a Name from a name string. The +// format of a valid name string is: // -// [host/][namespace/][:tag][+build][@-] +// s: +// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest } +// { host } "/" { namespace } "/" { model } ":" { tag } +// { host } "/" { namespace } "/" { model } "@" { digest } +// { host } "/" { namespace } "/" { model } +// { namespace } "/" { model } ":" { tag } "@" { digest } +// { namespace } "/" { model } ":" { tag } +// { namespace } "/" { model } "@" { digest } +// { namespace } "/" { model } +// { model } ":" { tag } "@" { digest } +// { model } ":" { tag } +// { model } "@" { digest } +// { model } +// "@" { digest } +// host: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }* +// length: [1, 350] +// namespace: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" }* +// length: [1, 80] +// model: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* +// length: [1, 80] +// tag: +// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }* +// length: [1, 80] +// digest: +// pattern: { alphanum | "_" } { alphanum | "-" | ":" }* +// length: [1, 80] // -// The name part is required, all others are optional. If a part is missing, -// it is left empty in the returned Name. If a part is invalid, the zero Ref -// value is returned. +// Most users should use [ParseName]. Use [ParseNameBare] instead only if +// you need to support different defaults than [DefaultName]. // -// The build part is normalized to uppercase. -// -// Examples of valid paths: -// -// "example.com/library/mistral:7b+x" -// "example.com/eva/mistral:7b+Q4_0" -// "mistral:7b+x" -// "example.com/mike/mistral:latest+Q4_0" -// "example.com/bruce/mistral:latest" -// "example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef" -// -// Examples of invalid paths: -// -// "example.com/mistral:7b+" -// "example.com/mistral:7b+Q4_0+" -// "x/y/z/z:8n+I" -// "" -// -// It returns the zero value if any part is invalid. -// -// # Fills -// -// For any valid s, the fill string is used to fill in missing parts of the -// Name. The fill string must be a valid Name with the exception that any part -// may be the string ("?"), which will not be considered for filling. -func ParseName(s, fill string) Name { - var r Name - parts(s)(func(kind PartKind, part string) bool { - if kind == PartDigest && !ParseDigest(part).IsValid() { - r = Name{} - return false - } - if kind == PartExtraneous || !isValidPart(kind, part) { - r = Name{} - return false - } - r.parts[kind] = part - return true - }) - if r.IsValid() || r.IsResolved() { - return fillName(r, fill) +// The name returned is not guaranteed to be valid. If it is not valid, the +// field values are left in an undefined state. Use [Name.IsValid] to check +// if the name is valid. +func ParseName(s string) Name { + return Merge(ParseNameBare(s), DefaultName()) +}
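A minimal usage sketch (illustrative, not part of the patch) contrasting the two entry points:

package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// ParseName merges in the default host, namespace, and tag.
	n := model.ParseName("mistral")
	fmt.Println(n.String())           // registry.ollama.ai/library/mistral:latest
	fmt.Println(n.IsFullyQualified()) // true

	// ParseNameBare leaves the missing parts empty instead.
	fmt.Println(model.ParseNameBare("mistral").String()) // mistral
}

+ +// ParseNameBare parses s as a name string and returns a Name. No merge with +// [DefaultName] is performed. 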
+func ParseNameBare(s string) Name { + var n Name + var promised bool + + s, n.RawDigest, promised = cutLast(s, "@") + if promised && n.RawDigest == "" { + n.RawDigest = MissingPart } - return Name{} -} -func parseMask(s string) Name { - var r Name - parts(s)(func(kind PartKind, part string) bool { - if part == "?" { - // mask part; treat as empty but valid - return true - } - if !isValidPart(kind, part) { - panic(fmt.Errorf("invalid mask part %s: %q", kind, part)) - } - r.parts[kind] = part - return true - }) - return r -} - -func MustParseName(s, fill string) Name { - r := ParseName(s, fill) - if !r.IsValid() { - panic("invalid Name: " + s) + // "/" is an illegal tag character, so we can use it to split the host + if strings.LastIndex(s, ":") > strings.LastIndex(s, "/") { + s, n.Tag, _ = cutPromised(s, ":") } - return r -} -// fillName fills in the missing parts of dst with the parts of src. -// -// The returned Name will only be valid if dst is valid. -// -// It skipps fill parts that are "?". -func fillName(r Name, fill string) Name { - fill = cmp.Or(fill, FillDefault) - f := parseMask(fill) - if fill != FillNothing && f.IsZero() { - panic("invalid fill") + s, n.Model, promised = cutPromised(s, "/") + if !promised { + n.Model = s + return n } - for i := range r.parts { - if f.parts[i] == "?" { - continue - } - r.parts[i] = cmp.Or(r.parts[i], f.parts[i]) + + s, n.Namespace, promised = cutPromised(s, "/") + if !promised { + n.Namespace = s + return n } - return r -} -// WithBuild returns a copy of r with the build set to the given string. -func (r Name) WithBuild(build string) Name { - r.parts[PartBuild] = build - return r -} - -func (r Name) WithDigest(digest Digest) Name { - r.parts[PartDigest] = digest.String() - return r -} - -var mapHashSeed = maphash.MakeSeed() - -// MapHash returns a case insensitive hash for use in maps and equality -// checks. For a convenient way to compare names, use [Name.EqualFold]. -// -//nolint:errcheck -func (r Name) MapHash() uint64 { - // correctly hash the parts with case insensitive comparison - var h maphash.Hash - h.SetSeed(mapHashSeed) - for _, part := range r.parts { - // downcase the part for hashing - for i := range part { - c := part[i] - if c >= 'A' && c <= 'Z' { - c = c - 'A' + 'a' - } - h.WriteByte(c) - } + scheme, host, ok := strings.Cut(s, "://") + if !ok { + host = scheme } - return h.Sum64() + n.Host = host + + return n } -func (r Name) slice(from, to PartKind) Name { - var v Name - copy(v.parts[from:to+1], r.parts[from:to+1]) - return v -} - -// DisplayShortest returns the shortest possible, masked display string in form: +// ParseNameFromFilepath parses a 4-part filepath as a Name. The parts are +// expected to be in the form: // -// [host/][/][:] -// -// # Masks -// -// The mask is a string that specifies which parts of the name to omit based -// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name -// that are the same as the mask, moving from left to right until the first -// unequal part is found. It then moves right to left until the first unequal -// part is found. The result is the shortest possible display string. -// -// Unlike a [Name] the mask can contain "?" characters which are treated as -// wildcards. A "?" will never match a part of the name, since a valid name -// can never contain a "?" character. -// -// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked -// with ("registry.ollama.ai/library/?:latest") will produce the display string -// ("mistral"). 
-// -// If mask is the empty string, then [MaskDefault] is used. -// -// DisplayShortest panics if the mask is not the empty string, MaskNothing, and -// invalid. -// -// # Builds -// -// For now, DisplayShortest does consider the build or return one in the -// result. We can lift this restriction when needed. -func (r Name) DisplayShortest(mask string) string { - mask = cmp.Or(mask, MaskDefault) - d := parseMask(mask) - if mask != MaskNothing && r.IsZero() { - panic("invalid Name") + { host } "/" { namespace } "/" { model } "/" { tag } +func ParseNameFromFilepath(s string) (n Name) { + parts := strings.Split(s, string(filepath.Separator)) + if len(parts) != 4 { + return Name{} } - for i := range PartTag { - if !strings.EqualFold(r.parts[i], d.parts[i]) { - break - } - r.parts[i] = "" + + n.Host = parts[0] + n.Namespace = parts[1] + n.Model = parts[2] + n.Tag = parts[3] + if !n.IsFullyQualified() { + return Name{} } - for i := PartTag; i >= 0; i-- { - if !strings.EqualFold(r.parts[i], d.parts[i]) { - break - } - r.parts[i] = "" + + return n +} + +// Merge merges the host, namespace, and tag parts of the two names, +// preferring the non-empty parts of a. +func Merge(a, b Name) Name { + a.Host = cmp.Or(a.Host, b.Host) + a.Namespace = cmp.Or(a.Namespace, b.Namespace) + a.Tag = cmp.Or(a.Tag, b.Tag) + return a +} + +// String returns the name string, in the format that [ParseNameBare] +// accepts as valid. The result is only meaningful if [Name.IsValid] +// reports true. +func (n Name) String() string { + var b strings.Builder + if n.Host != "" { + b.WriteString(n.Host) + b.WriteByte('/') } - return r.slice(PartHost, PartTag).DisplayLong() -} - -// DisplayLongest returns the result of r.DisplayShortest(MaskNothing). -func (r Name) DisplayLongest() string { - return r.DisplayShortest(MaskNothing) -} - -var seps = [...]string{ - PartHost: "/", - PartNamespace: "/", - PartModel: ":", - PartTag: "+", - PartBuild: "@", - PartDigest: "", -} - -// WriteTo implements io.WriterTo. It writes the fullest possible display -// string in form: -// -// //:+@- -// -// Missing parts and their separators are not written. -// -// The full digest is always prefixed with "@". That is if [Name.IsValid] -// reports false and [Name.IsResolved] reports true, then the string is -// returned as "@-". -func (r Name) writeTo(w io.StringWriter) error { - var partsWritten int - for i := range r.parts { - if r.parts[i] == "" { - continue - } - if partsWritten > 0 || i == int(PartDigest) { - if _, err := w.WriteString(seps[i-1]); err != nil { - return err - } - } - if _, err := w.WriteString(r.parts[i]); err != nil { - return err - } - partsWritten++ + if n.Namespace != "" { + b.WriteString(n.Namespace) + b.WriteByte('/') + } + b.WriteString(n.Model) + if n.Tag != "" { + b.WriteByte(':') + b.WriteString(n.Tag) + } + if n.RawDigest != "" { + b.WriteByte('@') + b.WriteString(n.RawDigest) } - return nil -} - -var builderPool = sync.Pool{ - New: func() interface{} { - return &strings.Builder{} - }, -} - -// DisplayLong returns the fullest possible display string in form: -// -// //:+ -// -// If any part is missing, it is omitted from the display string. -func (r Name) DisplayLong() string { - b := builderPool.Get().(*strings.Builder) - defer builderPool.Put(b) - b.Reset() - b.Grow(50) // arbitrarily long enough for most names - _ = r.writeTo(b) return b.String() } -// GoString implements fmt.GoStringer. It returns a string suitable for -// debugging and logging. 
It is similar to [Name.DisplayLong] but it always -// returns a string that includes all parts of the Name, with missing parts -// replaced with a ("?"). -func (r Name) GoString() string { - for i := range r.parts { - r.parts[i] = cmp.Or(r.parts[i], "?") +// DisplayShortest returns a short string version of the name, omitting the +// host when it is the default host and the namespace when both the host and +// the namespace are the defaults. The model and tag are always included. +func (n Name) DisplayShortest() string { + var sb strings.Builder + + if n.Host != defaultHost { + sb.WriteString(n.Host) + sb.WriteByte('/') + sb.WriteString(n.Namespace) + sb.WriteByte('/') + } else if n.Namespace != defaultNamespace { + sb.WriteString(n.Namespace) + sb.WriteByte('/') } - return r.DisplayLong() + + // always include model and tag + sb.WriteString(n.Model) + sb.WriteString(":") + sb.WriteString(n.Tag) + return sb.String() +}
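A small sketch (illustrative, not part of the patch) of how default parts are elided:

package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// Default host and namespace collapse away; anything else is kept.
	fmt.Println(model.ParseName("registry.ollama.ai/library/mistral:latest").DisplayShortest()) // mistral:latest
	fmt.Println(model.ParseName("example.com/ns/mistral:7b").DisplayShortest())                 // example.com/ns/mistral:7b
}

-// LogValue implements slog.Valuer. -func (r Name) LogValue() slog.Value { - return slog.StringValue(r.GoString()) -} - -// IsComplete reports whether the Name is fully qualified. That is it has a -// domain, namespace, name, tag, and build. -func (r Name) IsComplete() bool { - return !slices.Contains(r.parts[:PartDigest], "") -} - -// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the -// build part to be present. -func (r Name) IsCompleteNoBuild() bool { - return !slices.Contains(r.parts[:PartBuild], "") -} - -// IsResolved reports true if the Name has a valid digest. -// -// It is possible to have a valid Name, or a complete Name that is not -// resolved. -func (r Name) IsResolved() bool { - return r.Digest().IsValid() -} - -// Digest returns the digest part of the Name, if any. -// -// If Digest returns a non-empty string, then [Name.IsResolved] will return -// true, and digest is considered valid. -func (r Name) Digest() Digest { - // This was already validated by ParseName, so we can just return it. - return Digest{r.parts[PartDigest]} -} - -// EqualFold reports whether r and o are equivalent model names, ignoring -// case. -func (r Name) EqualFold(o Name) bool { - return r.CompareFold(o) == 0 -} - -// CompareFold performs a case-insensitive cmp.Compare on r and o. -// -// This can be used with [slices.SortFunc]. -// -// For simple equality checks, use [Name.EqualFold]. -func (r Name) CompareFold(o Name) int { - return slices.CompareFunc(r.parts[:], o.parts[:], compareFold) -} - -func compareFold(a, b string) int { - return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int { - return cmp.Compare(downcase(a), downcase(b)) - }) -} - -func downcase(r rune) rune { - if r >= 'A' && r <= 'Z' { - return r - 'A' + 'a' - } - return r -} - -func (r Name) Host() string { return r.parts[PartHost] } -func (r Name) Namespace() string { return r.parts[PartNamespace] } -func (r Name) Model() string { return r.parts[PartModel] } -func (r Name) Build() string { return r.parts[PartBuild] } -func (r Name) Tag() string { return r.parts[PartTag] } - -// iter_Seq2 is a iter.Seq2 defined here to avoid the current build -// restrictions in the go1.22 iter package requiring the -// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag, -// which we are not yet ready to support. -// -// Once we are ready to support rangefunc, this can be removed and replaced -// with the iter.Seq2 type. -type iter_Seq2[A, B any] func(func(A, B) bool) - -// Parts returns a sequence of the parts of a Name string from most specific -// to least specific. -// -// It normalizes the input string by removing "http://" and "https://" only. -// No other normalizations are performed. 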
-func parts(s string) iter_Seq2[PartKind, string] { - return func(yield func(PartKind, string) bool) { - if strings.HasPrefix(s, "http://") { - s = strings.TrimPrefix(s, "http://") - } else { - s = strings.TrimPrefix(s, "https://") - } - - if len(s) > MaxNamePartLen || len(s) == 0 { - return - } - - numConsecutiveDots := 0 - partLen := 0 - state, j := PartDigest, len(s) - for i := len(s) - 1; i >= 0; i-- { - if partLen++; partLen > MaxNamePartLen { - // catch a part that is too long early, so - // we don't keep spinning on it, waiting for - // an isInValidPart check which would scan - // over it again. - yield(state, s[i+1:j]) - return - } - - switch s[i] { - case '@': - switch state { - case PartDigest: - if !yield(PartDigest, s[i+1:j]) { - return - } - if i == 0 { - // This is the form - // "@" which is valid. - // - // We're done. - return - } - state, j, partLen = PartBuild, i, 0 - default: - yield(PartExtraneous, s[i+1:j]) - return - } - case '+': - switch state { - case PartBuild, PartDigest: - if !yield(PartBuild, s[i+1:j]) { - return - } - state, j, partLen = PartTag, i, 0 - default: - yield(PartExtraneous, s[i+1:j]) - return - } - case ':': - switch state { - case PartTag, PartBuild, PartDigest: - if !yield(PartTag, s[i+1:j]) { - return - } - state, j, partLen = PartModel, i, 0 - case PartHost: - // noop: support for host:port - default: - yield(PartExtraneous, s[i+1:j]) - return - } - case '/': - switch state { - case PartModel, PartTag, PartBuild, PartDigest: - if !yield(PartModel, s[i+1:j]) { - return - } - state, j = PartNamespace, i - case PartNamespace: - if !yield(PartNamespace, s[i+1:j]) { - return - } - state, j, partLen = PartHost, i, 0 - default: - yield(PartExtraneous, s[i+1:j]) - return - } - default: - if s[i] == '.' { - if numConsecutiveDots++; numConsecutiveDots > 1 { - yield(state, "") - return - } - } else { - numConsecutiveDots = 0 - } - } - } - - if state <= PartNamespace { - yield(state, s[:j]) - } else { - yield(PartModel, s[:j]) - } - } -} - -func (r Name) IsZero() bool { - return r.parts == [NumParts]string{} -} - -// IsValid reports if a model has at minimum a valid model part. -func (r Name) IsValid() bool { - // Parts ensures we only have valid parts, so no need to validate - // them here, only check if we have a name or not. - return r.parts[PartModel] != "" -} - -// ParseNameFromURLPath parses forms of a URL path into a Name. Specifically, -// it trims any leading "/" and then calls [ParseName] with fill. -func ParseNameFromURLPath(s, fill string) Name { - s = strings.TrimPrefix(s, "/") - return ParseName(s, fill) -} - -// URLPath returns a complete, canonicalized, relative URL path using the parts of a -// complete Name. -// -// The parts maintain their original case. -// -// Example: -// -// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag" -func (r Name) URLPath() string { - return r.DisplayShortest(MaskNothing) -} - -// ParseNameFromFilepath parses a file path into a Name. The input string must be a -// valid file path representation of a model name in the form: -// -// host/namespace/model/tag/build -// -// The zero valid is returned if s does not contain all path elements -// leading up to the model part, or if any path element is an invalid part -// for the its corresponding part kind. -// -// The fill string is used to fill in missing parts of any constructed Name. -// See [ParseName] for more information on the fill string. 
-func ParseNameFromFilepath(s, fill string) Name { - var r Name - for i := range PartBuild + 1 { - part, rest, _ := strings.Cut(s, string(filepath.Separator)) - if !isValidPart(i, part) { - return Name{} - } - r.parts[i] = part - s = rest - if s == "" { - break - } - } - if s != "" { - return Name{} - } - if !r.IsValid() { - return Name{} - } - return fillName(r, fill) -} - -// Filepath returns a complete, canonicalized, relative file path using the -// parts of a complete Name. -// -// Each parts is downcased, except for the build part which is upcased. -// -// Example: -// -// ParseName("example.com/namespace/model:tag+build").Filepath() // returns "example.com/namespace/model/tag/BUILD" -func (r Name) Filepath() string { - for i := range r.parts { - if PartKind(i) == PartBuild { - r.parts[i] = strings.ToUpper(r.parts[i]) - } else { - r.parts[i] = strings.ToLower(r.parts[i]) - } - } - return filepath.Join(r.parts[:]...) -} - -// FilepathNoBuild returns a complete, canonicalized, relative file path using -// the parts of a complete Name, but without the build part. -func (r Name) FilepathNoBuild() string { - for i := range PartBuild { - r.parts[i] = strings.ToLower(r.parts[i]) - } - return filepath.Join(r.parts[:PartBuild]...) -} - -// isValidPart reports if s contains all valid characters for the given -// part kind. -func isValidPart(kind PartKind, s string) bool { - if s == "" { +// IsValid reports whether all parts of the name are present and valid. The +// digest is a special case, and is checked for validity only if present. +func (n Name) IsValid() bool { + if n.RawDigest != "" && !isValidPart(kindDigest, n.RawDigest) { return false } - var consecutiveDots int - for _, c := range []byte(s) { - if c == '.' { - if consecutiveDots++; consecutiveDots >= 2 { - return false - } - } else { - consecutiveDots = 0 - } - if !isValidByteFor(kind, c) { + return n.IsFullyQualified() +} + +// IsFullyQualified reports whether all parts of the name other than the +// digest are present and valid. +func (n Name) IsFullyQualified() bool { + var parts = []string{ + n.Host, + n.Namespace, + n.Model, + n.Tag, + } + for i, part := range parts { + if !isValidPart(partKind(i), part) { return false } } return true } -func isValidByteFor(kind PartKind, c byte) bool { - if kind == PartNamespace && c == '.' { +// Filepath returns a canonical filepath that represents the name with each part from +// host to tag as a directory in the form: +// +// {host}/{namespace}/{model}/{tag} +// +// It uses the system's filepath separator and ensures the path is clean. +// +// It panics if the name is not fully qualified. Use [Name.IsFullyQualified] +// to check if the name is fully qualified. +func (n Name) Filepath() string { + if !n.IsFullyQualified() { + panic("illegal attempt to get filepath of invalid name") + } + return filepath.Join( + strings.ToLower(filepath.Join( + n.Host, + n.Namespace, + n.Model, + )), + n.Tag, + ) +}
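A small sketch (illustrative, not part of the patch; the name is arbitrary and the output assumes a Unix path separator):

package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// Host, namespace, and model are lowercased on the way to disk.
	n := model.ParseName("Example.COM/Staff/Mistral:7b")
	fmt.Println(n.Filepath()) // example.com/staff/mistral/7b
}

+ +// LogValue returns a slog.Value that represents the name as a string. 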
+func (n Name) LogValue() slog.Value {
+	return slog.StringValue(n.String())
+}
+
+func isValidLen(kind partKind, s string) bool {
+	switch kind {
+	case kindHost:
+		return len(s) >= 1 && len(s) <= 350
+	case kindTag:
+		return len(s) >= 1 && len(s) <= 80
+	default:
+		return len(s) >= 1 && len(s) <= 80
+	}
+}
+
+func isValidPart(kind partKind, s string) bool {
+	if !isValidLen(kind, s) {
 		return false
 	}
-	if kind == PartHost && c == ':' {
-		return true
+	for i := range s {
+		if i == 0 {
+			if !isAlphanumericOrUnderscore(s[i]) {
+				return false
+			}
+			continue
+		}
+		switch s[i] {
+		case '_', '-':
+		case '.':
+			if kind == kindNamespace {
+				return false
+			}
+		case ':':
+			if kind != kindHost && kind != kindDigest {
+				return false
+			}
+		default:
+			if !isAlphanumericOrUnderscore(s[i]) {
+				return false
+			}
+		}
 	}
-	if c == '.' || c == '-' {
-		return true
-	}
-	if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
-		return true
-	}
-	return false
+	return true
+}
+
+func isAlphanumericOrUnderscore(c byte) bool {
+	return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
+}
+
+func cutLast(s, sep string) (before, after string, ok bool) {
+	i := strings.LastIndex(s, sep)
+	if i >= 0 {
+		return s[:i], s[i+len(sep):], true
+	}
+	return s, "", false
+}
+
+// cutPromised cuts s at the last occurrence of sep. If sep is found, the
+// parts before and after it are returned as-is unless empty, in which case
+// they are returned as MissingPart, which causes [Name.IsValid] to report
+// false.
+func cutPromised(s, sep string) (before, after string, ok bool) {
+	before, after, ok = cutLast(s, sep)
+	if !ok {
+		return before, after, false
+	}
+	return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
+}
+
+type DigestType byte
+
+const (
+	DigestTypeInvalid DigestType = iota
+	DigestTypeSHA256
+)
+
+func (t DigestType) String() string {
+	switch t {
+	case DigestTypeSHA256:
+		return "sha256"
+	default:
+		return "invalid"
+	}
+}
+
+type Digest struct {
+	Type DigestType
+	Sum  [32]byte
+}
+
+func ParseDigest(s string) (Digest, error) {
+	i := strings.IndexAny(s, "-:")
+	if i < 0 {
+		return Digest{}, fmt.Errorf("invalid digest %q", s)
+	}
+	typ, encSum := s[:i], s[i+1:]
+	if typ != "sha256" {
+		return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
+	}
+	d := Digest{
+		Type: DigestTypeSHA256,
+	}
+	n, err := hex.Decode(d.Sum[:], []byte(encSum))
+	if err != nil {
+		return Digest{}, err
+	}
+	if n != 32 {
+		return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
+	}
+	return d, nil
+}
+
+func (d Digest) String() string {
+	if d.Type == DigestTypeInvalid {
+		return ""
+	}
+	return fmt.Sprintf("sha256-%x", d.Sum)
+}
+
+func (d Digest) IsValid() bool {
+	return d.Type != DigestTypeInvalid
 }
diff --git a/types/model/name_test.go b/types/model/name_test.go
index 8749477a..fb584291 100644
--- a/types/model/name_test.go
+++ b/types/model/name_test.go
@@ -1,709 +1,387 @@
 package model
 
 import (
-	"bytes"
-	"cmp"
-	"fmt"
-	"log/slog"
 	"path/filepath"
-	"slices"
-	"strings"
+	"reflect"
+	"runtime"
 	"testing"
 )
 
-type fields struct {
-	host, namespace, model, tag, build string
-	digest                             string
-}
+const (
+	part80  = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
+	part350 = 
"33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333" +) -func fieldsFromName(p Name) fields { - return fields{ - host: p.parts[PartHost], - namespace: p.parts[PartNamespace], - model: p.parts[PartModel], - tag: p.parts[PartTag], - build: p.parts[PartBuild], - digest: p.parts[PartDigest], - } -} - -var testNames = map[string]fields{ - "mistral:latest": {model: "mistral", tag: "latest"}, - "mistral": {model: "mistral"}, - "mistral:30B": {model: "mistral", tag: "30B"}, - "mistral:7b": {model: "mistral", tag: "7b"}, - "mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"}, - "mistral+KQED": {model: "mistral", build: "KQED"}, - "mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"}, - "mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"}, - "llama2": {model: "llama2"}, - "user/model": {namespace: "user", model: "model"}, - "example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"}, - "example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"}, - "localhost:5000/ns/mistral": {host: "localhost:5000", namespace: "ns", model: "mistral"}, - - // invalid digest - "mistral:latest@invalid256-": {}, - "mistral:latest@-123": {}, - "mistral:latest@!-123": {}, - "mistral:latest@1-!": {}, - "mistral:latest@": {}, - - // resolved - "x@sha123-1": {model: "x", digest: "sha123-1"}, - "@sha456-2": {digest: "sha456-2"}, - - "@@sha123-1": {}, - - // preserves case for build - "x+b": {model: "x", build: "b"}, - - // invalid (includes fuzzing trophies) - " / / : + ": {}, - " / : + ": {}, - " : + ": {}, - " + ": {}, - " : ": {}, - " / ": {}, - " /": {}, - "/ ": {}, - "/": {}, - ":": {}, - "+": {}, - - // (".") in namepsace is not allowed - "invalid.com/7b+x": {}, - - "invalid:7b+Q4_0:latest": {}, - "in valid": {}, - "invalid/y/z/foo": {}, - "/0": {}, - "0 /0": {}, - "0 /": {}, - "0/": {}, - ":/0": {}, - "+0/00000": {}, - "0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {}, - "0//0": {}, - "m+^^^": {}, - "file:///etc/passwd": {}, - "file:///etc/passwd:latest": {}, - "file:///etc/passwd:latest+u": {}, - - ":x": {}, - "+x": {}, - "x+": {}, - - // Disallow ("\.+") in any part to prevent path traversal anywhere - // we convert the name to a path. - "../etc/passwd": {}, - ".../etc/passwd": {}, - "./../passwd": {}, - "./0+..": {}, - - strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)}, - strings.Repeat("a", MaxNamePartLen+1): {}, -} - -// TestConsecutiveDots tests that consecutive dots are not allowed in any -// part, to avoid path traversal. There also are some tests in testNames, but -// this test is more exhaustive and exists to emphasize the importance of -// preventing path traversal. 
-func TestNameConsecutiveDots(t *testing.T) { - for i := 1; i < 10; i++ { - s := strings.Repeat(".", i) - if i > 1 { - if g := ParseName(s, FillNothing).DisplayLong(); g != "" { - t.Errorf("ParseName(%q) = %q; want empty string", s, g) - } - } else { - if g := ParseName(s, FillNothing).DisplayLong(); g != s { - t.Errorf("ParseName(%q) = %q; want %q", s, g, s) - } - } - } -} - -func TestNameParts(t *testing.T) { - var p Name - if w, g := int(NumParts), len(p.parts); w != g { - t.Errorf("Parts() = %d; want %d", g, w) - } -} - -func TestNamePartString(t *testing.T) { - if g := PartKind(-2).String(); g != "Unknown" { - t.Errorf("Unknown part = %q; want %q", g, "Unknown") - } - for kind, name := range kindNames { - if g := kind.String(); g != name { - t.Errorf("%s = %q; want %q", kind, g, name) - } - } -} - -func TestParseName(t *testing.T) { - for baseName, want := range testNames { - for _, prefix := range []string{"", "https://", "http://"} { - // We should get the same results with or without the - // http(s) prefixes - s := prefix + baseName - - t.Run(s, func(t *testing.T) { - name := ParseName(s, FillNothing) - got := fieldsFromName(name) - if got != want { - t.Errorf("ParseName(%q) = %q; want %q", s, got, want) - } - - // test round-trip - if !ParseName(name.DisplayLong(), FillNothing).EqualFold(name) { - t.Errorf("ParseName(%q).String() = %s; want %s", s, name.DisplayLong(), baseName) - } - }) - } - } -} - -func TestParseNameFill(t *testing.T) { - cases := []struct { - in string - fill string - want string - }{ - {"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"}, - {"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"}, - {"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"}, - - // Invalid - {"", "example.com/library/?:latest+Q4_0", ""}, - {"llama2:?", "example.com/library/?:latest+Q4_0", ""}, - } - - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - name := ParseName(tt.in, tt.fill) - if g := name.DisplayLong(); g != tt.want { - t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want) - } - }) - } - - t.Run("invalid fill", func(t *testing.T) { - defer func() { - if recover() == nil { - t.Fatal("expected panic") - } - }() - ParseName("x", "^") - }) -} - -func TestParseNameHTTPDoublePrefixStrip(t *testing.T) { - cases := []string{ - "http://https://valid.com/valid/valid:latest", - "https://http://valid.com/valid/valid:latest", - } - for _, s := range cases { - t.Run(s, func(t *testing.T) { - name := ParseName(s, FillNothing) - if name.IsValid() { - t.Errorf("expected invalid path; got %#v", name) - } - }) - } - -} - -func TestCompleteWithAndWithoutBuild(t *testing.T) { +func TestParseNameParts(t *testing.T) { cases := []struct { in string - complete bool - completeNoBuild bool + want Name + wantFilepath string + wantValidDigest bool }{ - {"", false, false}, - {"incomplete/mistral:7b+x", false, false}, - {"incomplete/mistral:7b+Q4_0", false, false}, - {"incomplete:7b+x", false, false}, - {"complete.com/x/mistral:latest+Q4_0", true, true}, - {"complete.com/x/mistral:latest", false, true}, + { + in: "registry.ollama.ai/library/dolphin-mistral:7b-v2.6-dpo-laser-q6_K", + want: Name{ + Host: "registry.ollama.ai", + Namespace: "library", + Model: "dolphin-mistral", + Tag: "7b-v2.6-dpo-laser-q6_K", + }, + wantFilepath: filepath.Join("registry.ollama.ai", "library", "dolphin-mistral", "7b-v2.6-dpo-laser-q6_K"), + }, + { + in: 
"scheme://host:port/namespace/model:tag", + want: Name{ + Host: "host:port", + Namespace: "namespace", + Model: "model", + Tag: "tag", + }, + wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"), + }, + { + in: "host/namespace/model:tag", + want: Name{ + Host: "host", + Namespace: "namespace", + Model: "model", + Tag: "tag", + }, + wantFilepath: filepath.Join("host", "namespace", "model", "tag"), + }, + { + in: "host:port/namespace/model:tag", + want: Name{ + Host: "host:port", + Namespace: "namespace", + Model: "model", + Tag: "tag", + }, + wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"), + }, + { + in: "host/namespace/model", + want: Name{ + Host: "host", + Namespace: "namespace", + Model: "model", + }, + wantFilepath: filepath.Join("host", "namespace", "model", "latest"), + }, + { + in: "host:port/namespace/model", + want: Name{ + Host: "host:port", + Namespace: "namespace", + Model: "model", + }, + wantFilepath: filepath.Join("host:port", "namespace", "model", "latest"), + }, + { + in: "namespace/model", + want: Name{ + Namespace: "namespace", + Model: "model", + }, + wantFilepath: filepath.Join("registry.ollama.ai", "namespace", "model", "latest"), + }, + { + in: "model", + want: Name{ + Model: "model", + }, + wantFilepath: filepath.Join("registry.ollama.ai", "library", "model", "latest"), + }, + { + in: "h/nn/mm:t", + want: Name{ + Host: "h", + Namespace: "nn", + Model: "mm", + Tag: "t", + }, + wantFilepath: filepath.Join("h", "nn", "mm", "t"), + }, + { + in: part80 + "/" + part80 + "/" + part80 + ":" + part80, + want: Name{ + Host: part80, + Namespace: part80, + Model: part80, + Tag: part80, + }, + wantFilepath: filepath.Join(part80, part80, part80, part80), + }, + { + in: part350 + "/" + part80 + "/" + part80 + ":" + part80, + want: Name{ + Host: part350, + Namespace: part80, + Model: part80, + Tag: part80, + }, + wantFilepath: filepath.Join(part350, part80, part80, part80), + }, + { + in: "@digest", + want: Name{ + RawDigest: "digest", + }, + wantValidDigest: false, + }, + { + in: "model@sha256:123", + want: Name{ + Model: "model", + RawDigest: "sha256:123", + }, + wantValidDigest: true, + }, } for _, tt := range cases { t.Run(tt.in, func(t *testing.T) { - p := ParseName(tt.in, FillNothing) - t.Logf("ParseName(%q) = %#v", tt.in, p) - if g := p.IsComplete(); g != tt.complete { - t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete) + got := ParseNameBare(tt.in) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want) } - if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild { - t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild) - } - }) - } - // Complete uses Parts which returns a slice, but it should be - // inlined when used in Complete, preventing any allocations or - // escaping to the heap. 
-	allocs := testing.AllocsPerRun(1000, func() {
-		keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete())
-	})
-	if allocs > 0 {
-		t.Errorf("Complete allocs = %v; want 0", allocs)
-	}
-}
-
-func TestNameLogValue(t *testing.T) {
-	cases := []string{
-		"example.com/library/mistral:latest+Q4_0",
-		"mistral:latest",
-		"mistral:7b+Q4_0",
-	}
-	for _, s := range cases {
-		t.Run(s, func(t *testing.T) {
-			var b bytes.Buffer
-			log := slog.New(slog.NewTextHandler(&b, nil))
-			name := ParseName(s, FillNothing)
-			log.Info("", "name", name)
-			want := fmt.Sprintf("name=%s", name.GoString())
-			got := b.String()
-			if !strings.Contains(got, want) {
-				t.Errorf("expected log output to contain %q; got %q", want, got)
+			got = ParseName(tt.in)
+			if tt.wantFilepath != "" && got.Filepath() != tt.wantFilepath {
+				t.Errorf("parseName(%q).Filepath() = %q; want %q", tt.in, got.Filepath(), tt.wantFilepath)
 			}
 		})
 	}
 }
 
-func TestNameGoString(t *testing.T) {
+var testCases = map[string]bool{ // name -> valid
+	"": false,
+
+	"_why/_the/_lucky:_stiff": true,
+
+	// minimal
+	"h/n/m:t@d": true,
+
+	"host/namespace/model:tag": true,
+	"host/namespace/model": false,
+	"namespace/model": false,
+	"model": false,
+	"@sha256-1000000000000000000000000000000000000000000000000000000000000000": false,
+	"model@sha256-1000000000000000000000000000000000000000000000000000000000000000": false,
+	"model@sha256:1000000000000000000000000000000000000000000000000000000000000000": false,
+
+	// long (but valid)
+	part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
+	part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
+
+	"h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
+	"h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
+
+	// unqualified
+	"m": false,
+	"n/m:": false,
+	"h/n/m": false,
+	"@t": false,
+	"m@d": false,
+
+	// invalids
+	"^": false,
+	"mm:": false,
+	"/nn/mm": false,
+	"//": false,
+	"//mm": false,
+	"hh//": false,
+	"//mm:@": false,
+	"00@": false,
+	"@": false,
+
+	// not starting with alphanum
+	"-hh/nn/mm:tt@dd": false,
+	"hh/-nn/mm:tt@dd": false,
+	"hh/nn/-mm:tt@dd": false,
+	"hh/nn/mm:-tt@dd": false,
+	"hh/nn/mm:tt@-dd": false,
+
+	// hosts
+	"host:https/namespace/model:tag": true,
+
+	// colon in non-host part before tag
+	"host/name:space/model:tag": false,
+}
+
+func TestParseNameDefault(t *testing.T) {
+	const name = "xx"
+	n := ParseName(name)
+	got := n.String()
+	want := "registry.ollama.ai/library/xx:latest"
+	if got != want {
+		t.Errorf("parseName(%q).String() = %q; want %q", name, got, want)
+	}
+}
+
+func TestNameIsValid(t *testing.T) {
+	var numStringTests int
+	for s, want := range testCases {
+		n := ParseNameBare(s)
+		got := n.IsValid()
+		if got != want {
+			t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
+		}
+
+		// Test roundtrip with String
+		if got {
+			got := ParseNameBare(s).String()
+			if got != s {
+				t.Errorf("parseName(%q).String() = %q; want %q", s, got, s)
+			}
+			numStringTests++
+		}
+	}
+
+	if numStringTests == 0 {
+		t.Errorf("no tests for Name.String")
+	}
+}
+
+func TestNameIsValidPart(t *testing.T) {
 	cases := []struct {
-		name         string
-		in           string
-		wantString   string
-		wantGoString string // default is tt.in
+		kind partKind
+		s    string
+		want bool
 	}{
-		{
-			name:         "Complete Name",
-			in:           "example.com/library/mistral:latest+Q4_0",
-			wantGoString: "example.com/library/mistral:latest+Q4_0@?",
-		},
-		{
-			name:         "Short 
Name", - in: "mistral:latest", - wantGoString: "?/?/mistral:latest+?@?", - }, - { - name: "Long Name", - in: "library/mistral:latest", - wantGoString: "?/library/mistral:latest+?@?", - }, - { - name: "Case Preserved", - in: "Library/Mistral:Latest", - wantGoString: "?/Library/Mistral:Latest+?@?", - }, - { - name: "With digest", - in: "Library/Mistral:Latest@sha256-123456", - wantGoString: "?/Library/Mistral:Latest+?@sha256-123456", - }, + {kind: kindHost, s: "", want: false}, + {kind: kindHost, s: "a", want: true}, + {kind: kindHost, s: "a.", want: true}, + {kind: kindHost, s: "a.b", want: true}, + {kind: kindHost, s: "a:123", want: true}, + {kind: kindHost, s: "a:123/aa/bb", want: false}, + {kind: kindNamespace, s: "bb", want: true}, + {kind: kindNamespace, s: "a.", want: false}, + {kind: kindModel, s: "-h", want: false}, + {kind: kindDigest, s: "sha256-1000000000000000000000000000000000000000000000000000000000000000", want: true}, + } + for _, tt := range cases { + t.Run(tt.s, func(t *testing.T) { + got := isValidPart(tt.kind, tt.s) + if got != tt.want { + t.Errorf("isValidPart(%s, %q) = %v; want %v", tt.kind, tt.s, got, tt.want) + } + }) } +} + +func TestFilepathAllocs(t *testing.T) { + n := ParseNameBare("HOST/NAMESPACE/MODEL:TAG") + allocs := testing.AllocsPerRun(1000, func() { + n.Filepath() + }) + var allowedAllocs float64 = 3 + if runtime.GOOS == "windows" { + allowedAllocs = 5 + } + if allocs > allowedAllocs { + t.Errorf("allocs = %v; allowed %v", allocs, allowedAllocs) + } +} + +const ( + validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000" + validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000" +) + +func TestParseDigest(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"", ""}, // empty + {"sha123-12", ""}, // invalid type + {"sha256-", ""}, // invalid sum + {"sha256-123", ""}, // invalid odd length sum + + {validSha256, validSha256}, + {validSha256Old, validSha256}, + } for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - p := ParseName(tt.in, FillNothing) - tt.wantGoString = cmp.Or(tt.wantGoString, tt.in) - if g := fmt.Sprintf("%#v", p); g != tt.wantGoString { - t.Errorf("GoString() = %q; want %q", g, tt.wantGoString) + t.Run(tt.in, func(t *testing.T) { + got, err := ParseDigest(tt.in) + if err != nil { + if tt.want != "" { + t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want) + } + return + } + if got.String() != tt.want { + t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want) } }) } } -func TestDisplayLongest(t *testing.T) { - g := ParseName("example.com/library/mistral:latest+Q4_0", FillNothing).DisplayLongest() - if g != "example.com/library/mistral:latest" { - t.Errorf("got = %q; want %q", g, "example.com/library/mistral:latest") +func TestParseNameFromFilepath(t *testing.T) { + cases := map[string]Name{ + filepath.Join("host", "namespace", "model", "tag"): {Host: "host", Namespace: "namespace", Model: "model", Tag: "tag"}, + filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"}, + filepath.Join("namespace", "model", "tag"): {}, + filepath.Join("model", "tag"): {}, + filepath.Join("model"): {}, + filepath.Join("..", "..", "model", "tag"): {}, + filepath.Join("", "namespace", ".", "tag"): {}, + filepath.Join(".", ".", ".", "."): {}, + filepath.Join("/", "path", "to", "random", "file"): {}, + } + + for in, want := range cases { + t.Run(in, func(t *testing.T) { + 
got := ParseNameFromFilepath(in) + + if !reflect.DeepEqual(got, want) { + t.Errorf("parseNameFromFilepath(%q) = %v; want %v", in, got, want) + } + }) } } func TestDisplayShortest(t *testing.T) { - cases := []struct { - in string - mask string - want string - wantPanic bool - }{ - {"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false}, - {"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false}, - {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false}, - {"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false}, - - // case-insensitive - {"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false}, - {"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false}, - {"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false}, - {"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false}, - {"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false}, - - // zero value - {"", MaskDefault, "", true}, - - // invalid mask - {"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true}, - - // DefaultMask - {"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false}, - - // Auto-Fill - {"x", "example.com/library/_:latest", "x", false}, - {"x", "example.com/library/_:latest+Q4_0", "x", false}, - {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false}, - {"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false}, + cases := map[string]string{ + "registry.ollama.ai/library/model:latest": "model:latest", + "registry.ollama.ai/library/model:tag": "model:tag", + "registry.ollama.ai/namespace/model:tag": "namespace/model:tag", + "host/namespace/model:tag": "host/namespace/model:tag", + "host/library/model:tag": "host/library/model:tag", } - for _, tt := range cases { - t.Run("", func(t *testing.T) { - defer func() { - if tt.wantPanic { - if recover() == nil { - t.Errorf("expected panic") - } + for in, want := range cases { + t.Run(in, func(t *testing.T) { + got := ParseNameBare(in).DisplayShortest() + if got != want { + t.Errorf("parseName(%q).DisplayShortest() = %q; want %q", in, got, want) + } + }) + } +} + +func FuzzName(f *testing.F) { + for s := range testCases { + f.Add(s) + } + f.Fuzz(func(t *testing.T, s string) { + n := ParseNameBare(s) + if n.IsValid() { + parts := [...]string{n.Host, n.Namespace, n.Model, n.Tag, n.RawDigest} + for _, part := range parts { + if part == ".." { + t.Errorf("unexpected .. 
as valid part") + } + if len(part) > 350 { + t.Errorf("part too long: %q", part) } - }() - - p := ParseName(tt.in, FillNothing) - t.Logf("ParseName(%q) = %#v", tt.in, p) - if g := p.DisplayShortest(tt.mask); g != tt.want { - t.Errorf("got = %q; want %q", g, tt.want) } - }) - } -} - -func TestParseNameAllocs(t *testing.T) { - allocs := testing.AllocsPerRun(1000, func() { - keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing)) - }) - if allocs > 0 { - t.Errorf("ParseName allocs = %v; want 0", allocs) - } -} - -func BenchmarkParseName(b *testing.B) { - b.ReportAllocs() - - for range b.N { - keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing)) - } -} - -func FuzzParseNameFromFilepath(f *testing.F) { - f.Add("example.com/library/mistral/7b/Q4_0") - f.Add("example.com/../mistral/7b/Q4_0") - f.Add("example.com/x/../7b/Q4_0") - f.Add("example.com/x/../7b") - f.Fuzz(func(t *testing.T, s string) { - name := ParseNameFromFilepath(s, FillNothing) - if strings.Contains(s, "..") && !name.IsZero() { - t.Fatalf("non-zero value for path with '..': %q", s) - } - if name.IsValid() == name.IsZero() { - t.Errorf("expected valid path to be non-zero value; got %#v", name) + if n.String() != s { + t.Errorf("String() = %q; want %q", n.String(), s) + } } + }) } - -func FuzzParseName(f *testing.F) { - f.Add("example.com/mistral:7b+Q4_0") - f.Add("example.com/mistral:7b+q4_0") - f.Add("example.com/mistral:7b+x") - f.Add("x/y/z:8n+I") - f.Add(":x") - f.Add("@sha256-123456") - f.Add("example.com/mistral:latest+Q4_0@sha256-123456") - f.Add(":@!@") - f.Add("...") - f.Fuzz(func(t *testing.T, s string) { - r0 := ParseName(s, FillNothing) - - if strings.Contains(s, "..") && !r0.IsZero() { - t.Fatalf("non-zero value for path with '..': %q", s) - } - - if !r0.IsValid() && !r0.IsResolved() { - if !r0.EqualFold(Name{}) { - t.Errorf("expected invalid path to be zero value; got %#v", r0) - } - t.Skipf("invalid path: %q", s) - } - - for _, p := range r0.parts { - if len(p) > MaxNamePartLen { - t.Errorf("part too long: %q", p) - } - } - - if !strings.EqualFold(r0.DisplayLong(), s) { - t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.DisplayLong(), s) - } - - r1 := ParseName(r0.DisplayLong(), FillNothing) - if !r0.EqualFold(r1) { - t.Errorf("round-trip mismatch: %+v != %+v", r0, r1) - } - }) -} - -func TestNameStringAllocs(t *testing.T) { - name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing) - allocs := testing.AllocsPerRun(1000, func() { - keep(name.DisplayLong()) - }) - if allocs > 1 { - t.Errorf("String allocs = %v; want 0", allocs) - } -} - -func TestNamePath(t *testing.T) { - cases := []struct { - in string - want string - }{ - {"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest"}, - - // incomplete - {"example.com/library/mistral:latest", "example.com/library/mistral:latest"}, - {"", ""}, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - p := ParseName(tt.in, FillNothing) - t.Logf("ParseName(%q) = %#v", tt.in, p) - if g := p.URLPath(); g != tt.want { - t.Errorf("got = %q; want %q", g, tt.want) - } - }) - } -} - -func TestNameFilepath(t *testing.T) { - cases := []struct { - in string - want string - wantNoBuild string - }{ - { - in: "example.com/library/mistral:latest+Q4_0", - want: "example.com/library/mistral/latest/Q4_0", - wantNoBuild: "example.com/library/mistral/latest", - }, - { - in: "Example.Com/Library/Mistral:Latest+Q4_0", - want: "example.com/library/mistral/latest/Q4_0", - wantNoBuild: 
"example.com/library/mistral/latest", - }, - { - in: "Example.Com/Library/Mistral:Latest+Q4_0", - want: "example.com/library/mistral/latest/Q4_0", - wantNoBuild: "example.com/library/mistral/latest", - }, - { - in: "example.com/library/mistral:latest", - want: "example.com/library/mistral/latest", - wantNoBuild: "example.com/library/mistral/latest", - }, - { - in: "", - want: "", - wantNoBuild: "", - }, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - p := ParseName(tt.in, FillNothing) - t.Logf("ParseName(%q) = %#v", tt.in, p) - g := p.Filepath() - g = filepath.ToSlash(g) - if g != tt.want { - t.Errorf("got = %q; want %q", g, tt.want) - } - g = p.FilepathNoBuild() - g = filepath.ToSlash(g) - if g != tt.wantNoBuild { - t.Errorf("got = %q; want %q", g, tt.wantNoBuild) - } - }) - } -} - -func TestParseNameFilepath(t *testing.T) { - cases := []struct { - in string - fill string // default is FillNothing - want string - }{ - { - in: "example.com/library/mistral/latest/Q4_0", - want: "example.com/library/mistral:latest+Q4_0", - }, - { - in: "example.com/library/mistral/latest", - fill: "?/?/?:latest+Q4_0", - want: "example.com/library/mistral:latest+Q4_0", - }, - { - in: "example.com/library/mistral", - fill: "?/?/?:latest+Q4_0", - want: "example.com/library/mistral:latest+Q4_0", - }, - { - in: "example.com/library", - want: "", - }, - { - in: "example.com/", - want: "", - }, - { - in: "example.com/^/mistral/latest/Q4_0", - want: "", - }, - { - in: "example.com/library/mistral/../Q4_0", - want: "", - }, - { - in: "example.com/library/mistral/latest/Q4_0/extra", - want: "", - }, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - in := strings.ReplaceAll(tt.in, "/", string(filepath.Separator)) - fill := cmp.Or(tt.fill, FillNothing) - want := ParseName(tt.want, fill) - if g := ParseNameFromFilepath(in, fill); !g.EqualFold(want) { - t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want) - } - }) - } -} - -func TestParseNameFromPath(t *testing.T) { - cases := []struct { - in string - want string - fill string // default is FillNothing - }{ - { - in: "example.com/library/mistral:latest+Q4_0", - want: "example.com/library/mistral:latest+Q4_0", - }, - { - in: "/example.com/library/mistral:latest+Q4_0", - want: "example.com/library/mistral:latest+Q4_0", - }, - { - in: "/example.com/library/mistral", - want: "example.com/library/mistral", - }, - { - in: "/example.com/library/mistral", - fill: "?/?/?:latest+Q4_0", - want: "example.com/library/mistral:latest+Q4_0", - }, - { - in: "/example.com/library", - want: "", - }, - { - in: "/example.com/", - want: "", - }, - { - in: "/example.com/^/mistral/latest", - want: "", - }, - } - for _, tt := range cases { - t.Run(tt.in, func(t *testing.T) { - fill := cmp.Or(tt.fill, FillNothing) - if g := ParseNameFromURLPath(tt.in, fill); g.DisplayLong() != tt.want { - t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want) - } - }) - } -} - -func ExampleName_MapHash() { - m := map[uint64]bool{} - - // key 1 - m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true - m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true - m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true - - // key 2 - m[ParseName("mistral:LATest", FillNothing).MapHash()] = true - - fmt.Println(len(m)) - // Output: - // 2 -} - -func ExampleName_CompareFold_sort() { - names := []Name{ - ParseName("mistral:latest", FillNothing), - ParseName("mistRal:7b+q4", FillNothing), - ParseName("MIstral:7b", FillNothing), - } - - 
slices.SortFunc(names, Name.CompareFold) - - for _, n := range names { - fmt.Println(n.DisplayLong()) - } - - // Output: - // MIstral:7b - // mistRal:7b+q4 - // mistral:latest -} - -func ExampleName_completeAndResolved() { - for _, s := range []string{ - "x/y/z:latest+q4_0@sha123-1", - "x/y/z:latest+q4_0", - "@sha123-1", - } { - name := ParseName(s, FillNothing) - fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest()) - } - - // Output: - // complete:true resolved:true digest:sha123-1 - // complete:true resolved:false digest: - // complete:false resolved:true digest:sha123-1 -} - -func ExampleName_DisplayShortest() { - name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing) - - fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest")) - fmt.Println(name.DisplayShortest("example.com/_/_:latest")) - fmt.Println(name.DisplayShortest("example.com/_/_:_")) - fmt.Println(name.DisplayShortest("_/_/_:_")) - - // Default - name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing) - fmt.Println(name.DisplayShortest("")) - - // Output: - // mistral - // jmorganca/mistral - // jmorganca/mistral:latest - // example.com/jmorganca/mistral:latest - // mistral -} - -func keep[T any](v T) T { return v } diff --git a/types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa b/types/model/testdata/fuzz/FuzzName/d37463aa416f6bab similarity index 53% rename from types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa rename to types/model/testdata/fuzz/FuzzName/d37463aa416f6bab index 0cdf1eac..0034d9f5 100644 --- a/types/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa +++ b/types/model/testdata/fuzz/FuzzName/d37463aa416f6bab @@ -1,2 +1,2 @@ go test fuzz v1 -string("/0") +string("00@") diff --git a/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 b/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 deleted file mode 100644 index c5d09a4c..00000000 --- a/types/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6 +++ /dev/null @@ -1,2 +0,0 @@ -go test fuzz v1 -string("0//0") diff --git a/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d b/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d deleted file mode 100644 index 880ce7a3..00000000 --- a/types/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d +++ /dev/null @@ -1,2 +0,0 @@ -go test fuzz v1 -string("0 /0") diff --git a/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab b/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab deleted file mode 100644 index fa981c52..00000000 --- a/types/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab +++ /dev/null @@ -1,2 +0,0 @@ -go test fuzz v1 -string("+0/00000") diff --git a/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 b/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 deleted file mode 100644 index 0a66beb8..00000000 --- a/types/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608 +++ /dev/null @@ -1,2 +0,0 @@ -go test fuzz v1 -string(":") diff --git a/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 b/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 deleted file mode 100644 index db07727d..00000000 --- a/types/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948 +++ /dev/null @@ -1,2 +0,0 @@ -go test fuzz v1 
-string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91") diff --git a/types/structs/structs.go b/types/structs/structs.go deleted file mode 100644 index 52929ebf..00000000 --- a/types/structs/structs.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package structs contains the Incomparable type. -package structs - -// Incomparable is a zero-width incomparable type. If added as the -// first field in a struct, it marks that struct as not comparable -// (can't do == or be a map key) and usually doesn't add any width to -// the struct (unless the struct has only small fields). -// -// By making a struct incomparable, you can prevent misuse (prevent -// people from using ==), but also you can shrink generated binaries, -// as the compiler can omit equality funcs from the binary. -type Incomparable [0]func()