Merge branch 'ollama:main' into main

This commit is contained in:
mraiser 2024-01-27 21:56:11 -05:00 committed by GitHub
commit 4c4c730a0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 739 additions and 195 deletions

View file

@ -23,29 +23,72 @@ jobs:
with: with:
go-version: '1.21' go-version: '1.21'
cache: true cache: true
- if: ${{ startsWith(matrix.os, 'windows-') }}
shell: pwsh
run: |
$path = vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
if ($path) {
$path = join-path $path 'Common7\Tools\vsdevcmd.bat'
if (test-path $path) {
cmd /s /c """$path"" $args && set" | where { $_ -match '(\w+)=(.*)' } | foreach {
echo "$($Matches[1])=$($Matches[2])" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
}
}
}
echo "C:\Program Files\Git\usr\bin" | Out-File -FilePath $Env:GITHUB_PATH -Encoding utf8 -Append
- run: go get ./... - run: go get ./...
- run: go generate -x ./... - run: go generate -x ./...
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
path: | path: llm/llama.cpp/build/**/lib/*
llm/llama.cpp/build/**/lib/* generate-cuda:
strategy:
matrix:
cuda-version:
- '11.8.0'
runs-on: ubuntu-latest
container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: cuda-${{ matrix.cuda-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
generate-rocm:
strategy:
matrix:
rocm-version:
- '5.7.1'
- '6.0'
runs-on: ubuntu-latest
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl rocm-libs
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
lint: lint:
needs: generate
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
@ -69,10 +112,19 @@ jobs:
with: with:
go-version: '1.21' go-version: '1.21'
cache: false cache: false
- uses: actions/download-artifact@v4 - run: |
with: mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
path: llm/llama.cpp/build if: ${{ startsWith(matrix.os, 'ubuntu-') }}
- run: |
mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
touch llm/llama.cpp/ggml-metal.metal
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
if: ${{ startsWith(matrix.os, 'windows-') }}
- uses: golangci/golangci-lint-action@v3 - uses: golangci/golangci-lint-action@v3
test: test:
needs: generate needs: generate
@ -104,3 +156,7 @@ jobs:
path: llm/llama.cpp/build path: llm/llama.cpp/build
- run: go build - run: go build
- run: go test -v ./... - run: go test -v ./...
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-binaries
path: ollama

View file

@ -109,17 +109,28 @@ ARG CGO_CFLAGS
RUN go build . RUN go build .
# Runtime stages # Runtime stages
FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete as runtime-amd64 FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64 FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
RUN apt-get update && apt-get install -y ca-certificates RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
RUN update-pciids
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM runtime-$TARGETARCH FROM runtime-$TARGETARCH
EXPOSE 11434 EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0 ENV OLLAMA_HOST 0.0.0.0
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/opt/rocm/lib: ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENTRYPOINT ["/bin/ollama"] ENTRYPOINT ["/bin/ollama"]

View file

@ -42,6 +42,7 @@ type GenerateRequest struct {
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"` Raw bool `json:"raw,omitempty"`
Format string `json:"format"` Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
@ -52,6 +53,7 @@ type ChatRequest struct {
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Format string `json:"format"` Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@ -128,6 +130,7 @@ type Runner struct {
type EmbeddingRequest struct { type EmbeddingRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@ -171,6 +174,7 @@ type ShowResponse struct {
Template string `json:"template,omitempty"` Template string `json:"template,omitempty"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
} }
type CopyRequest struct { type CopyRequest struct {
@ -236,6 +240,7 @@ type GenerateResponse struct {
} }
type ModelDetails struct { type ModelDetails struct {
ParentModel string `json:"parent_model"`
Format string `json:"format"` Format string `json:"format"`
Family string `json:"family"` Family string `json:"family"`
Families []string `json:"families"` Families []string `json:"families"`
@ -411,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
case float64: case float64:
if t < 0 { if t < 0 {
t = math.MaxFloat64 t = math.MaxFloat64
}
d.Duration = time.Duration(t) d.Duration = time.Duration(t)
} else {
d.Duration = time.Duration(t * float64(time.Second))
}
case string: case string:
d.Duration, err = time.ParseDuration(t) d.Duration, err = time.ParseDuration(t)
if err != nil { if err != nil {
return err return err
} }
if d.Duration < 0 {
mf := math.MaxFloat64
d.Duration = time.Duration(mf)
}
} }
return nil return nil

View file

@ -459,6 +459,7 @@ type generateContextKey string
type runOptions struct { type runOptions struct {
Model string Model string
ParentModel string
Prompt string Prompt string
Messages []api.Message Messages []api.Message
WordWrap bool WordWrap bool
@ -467,6 +468,7 @@ type runOptions struct {
Template string Template string
Images []api.ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
MultiModal bool
} }
type displayResponseState struct { type displayResponseState struct {

View file

@ -7,12 +7,14 @@ import (
"net/http" "net/http"
"os" "os"
"regexp" "regexp"
"sort"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/progress"
"github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/readline"
) )
@ -25,33 +27,63 @@ const (
MultilineTemplate MultilineTemplate
) )
func modelIsMultiModal(cmd *cobra.Command, name string) bool { func loadModel(cmd *cobra.Command, opts *runOptions) error {
// get model details
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
fmt.Println("error: couldn't connect to ollama server") return err
return false
} }
req := api.ShowRequest{Name: name} p := progress.NewProgress(os.Stderr)
resp, err := client.Show(cmd.Context(), &req) defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
showReq := api.ShowRequest{Name: opts.Model}
showResp, err := client.Show(cmd.Context(), &showReq)
if err != nil { if err != nil {
return false return err
}
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
opts.ParentModel = showResp.Details.ParentModel
if len(showResp.Messages) > 0 {
opts.Messages = append(opts.Messages, showResp.Messages...)
} }
return slices.Contains(resp.Details.Families, "clip") chatReq := &api.ChatRequest{
Model: opts.Model,
Messages: []api.Message{},
}
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
p.StopAndClear()
if len(opts.Messages) > 0 {
for _, msg := range opts.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
}
return nil
})
if err != nil {
return err
}
return nil
} }
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model) opts.Messages = make([]api.Message, 0)
// load the model err := loadModel(cmd, &opts)
loadOpts := runOptions{ if err != nil {
Model: opts.Model,
Prompt: "",
Messages: []api.Message{},
}
if _, err := chat(cmd, loadOpts); err != nil {
return err return err
} }
@ -59,6 +91,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables") fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information") fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
@ -140,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder var sb strings.Builder
var multiline MultilineState var multiline MultilineState
opts.Messages = make([]api.Message, 0)
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@ -203,6 +236,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if err := ListHandler(cmd, args[1:]); err != nil { if err := ListHandler(cmd, args[1:]); err != nil {
return err return err
} }
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /load <modelname>")
continue
}
opts.Model = args[1]
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadModel(cmd, &opts); err != nil {
return err
}
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.CreateRequest{
Name: args[1],
Modelfile: buildModelfile(opts),
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Println("error: couldn't save model")
return err
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/set"): case strings.HasPrefix(line, "/set"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {
@ -389,7 +460,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
args := strings.Fields(line) args := strings.Fields(line)
isFile := false isFile := false
if multiModal { if opts.MultiModal {
for _, f := range extractFileNames(line) { for _, f := range extractFileNames(line) {
if strings.HasPrefix(f, args[0]) { if strings.HasPrefix(f, args[0]) {
isFile = true isFile = true
@ -411,7 +482,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if sb.Len() > 0 && multiline == MultilineNone { if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()} newMessage := api.Message{Role: "user", Content: sb.String()}
if multiModal { if opts.MultiModal {
msg, images, err := extractFileData(sb.String()) msg, images, err := extractFileData(sb.String())
if err != nil { if err != nil {
return err return err
@ -454,6 +525,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
} }
func buildModelfile(opts runOptions) string {
var mf strings.Builder
model := opts.ParentModel
if model == "" {
model = opts.Model
}
fmt.Fprintf(&mf, "FROM %s\n", model)
if opts.System != "" {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
}
fmt.Fprintln(&mf)
for _, msg := range opts.Messages {
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
}
return mf.String()
}
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements // Define a map of escaped characters and their replacements
replacements := map[string]string{ replacements := map[string]string{

View file

@ -1,9 +1,13 @@
package cmd package cmd
import ( import (
"bytes"
"testing" "testing"
"text/template"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api"
) )
func TestExtractFilenames(t *testing.T) { func TestExtractFilenames(t *testing.T) {
@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
assert.Contains(t, res[9], "ten.svg") assert.Contains(t, res[9], "ten.svg")
assert.Contains(t, res[9], "E:") assert.Contains(t, res[9], "E:")
} }
func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
},
Options: map[string]interface{}{},
}
opts.Options["temperature"] = 0.9
opts.Options["seed"] = 42
opts.Options["penalize_newline"] = false
opts.Options["stop"] = []string{"hi", "there"}
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err := template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, opts)
assert.Nil(t, err)
assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark"
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err = template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts)
assert.Nil(t, err)
assert.Equal(t, parentBuf.String(), mf)
}

View file

@ -50,7 +50,8 @@ development and runtime packages.
Typically the build scripts will auto-detect CUDA, however, if your Linux distro Typically the build scripts will auto-detect CUDA, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
libraries, and `CUDACXX` to the location of the nvcc compiler. libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
Then generate dependencies: Then generate dependencies:

View file

@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama.
- [SYSTEM](#system) - [SYSTEM](#system)
- [ADAPTER](#adapter) - [ADAPTER](#adapter)
- [LICENSE](#license) - [LICENSE](#license)
- [MESSAGE](#message)
- [Notes](#notes) - [Notes](#notes)
## Format ## Format
@ -38,6 +39,7 @@ INSTRUCTION arguments
| [`SYSTEM`](#system) | Specifies the system message that will be set in the template. | | [`SYSTEM`](#system) | Specifies the system message that will be set in the template. |
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. | | [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. | | [`LICENSE`](#license) | Specifies the legal license. |
| [`MESSAGE`](#message) | Specify message history. |
## Examples ## Examples
@ -205,6 +207,19 @@ LICENSE """
""" """
``` ```
### MESSAGE
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
```modelfile
MESSAGE user Is Toronto in Canada?
MESSAGE assistant yes
MESSAGE user Is Sacramento in Canada?
MESSAGE assistant no
MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes
```
## Notes ## Notes
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments. - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

View file

@ -16,6 +16,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
@ -29,8 +30,8 @@ type handles struct {
var gpuMutex sync.Mutex var gpuMutex sync.Mutex
var gpuHandles *handles = nil var gpuHandles *handles = nil
// With our current CUDA compile flags, 5.2 and older will not work properly // With our current CUDA compile flags, older than 5.0 will not work properly
const CudaComputeMajorMin = 6 var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library // Possible locations for the nvidia-ml library
var CudaLinuxGlobs = []string{ var CudaLinuxGlobs = []string{
@ -121,9 +122,15 @@ func GetGPUInfo() GpuInfo {
initGPUHandles() initGPUHandles()
} }
// All our GPU builds have AVX enabled, so fallback to CPU if we don't detect at least AVX
cpuVariant := GetCPUVariant()
if cpuVariant == "" {
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
}
var memInfo C.mem_info_t var memInfo C.mem_info_t
resp := GpuInfo{} resp := GpuInfo{}
if gpuHandles.cuda != nil { if gpuHandles.cuda != nil && cpuVariant != "" {
C.cuda_check_vram(*gpuHandles.cuda, &memInfo) C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
@ -135,19 +142,40 @@ func GetGPUInfo() GpuInfo {
if cc.err != nil { if cc.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err))) slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err)) C.free(unsafe.Pointer(cc.err))
} else if cc.major >= CudaComputeMajorMin { } else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)) slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda" resp.Library = "cuda"
} else { } else {
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)) slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
} }
} }
} else if gpuHandles.rocm != nil { } else if gpuHandles.rocm != nil && cpuVariant != "" {
C.rocm_check_vram(*gpuHandles.rocm, &memInfo) C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err)) C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
// Only one GPU detected and it appears to be an integrated GPU - skip it
slog.Info("ROCm unsupported integrated GPU detected")
} else { } else {
if memInfo.igpu_index >= 0 {
// We have multiple GPUs reported, and one of them is an integrated GPU
// so we have to set the env var to bypass it
// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
val := os.Getenv("ROCR_VISIBLE_DEVICES")
if val == "" {
devices := []string{}
for i := 0; i < int(memInfo.count); i++ {
if i == int(memInfo.igpu_index) {
continue
}
devices = append(devices, strconv.Itoa(i))
}
val = strings.Join(devices, ",")
os.Setenv("ROCR_VISIBLE_DEVICES", val)
}
slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
}
resp.Library = "rocm" resp.Library = "rocm"
var version C.rocm_version_resp_t var version C.rocm_version_resp_t
C.rocm_get_version(*gpuHandles.rocm, &version) C.rocm_get_version(*gpuHandles.rocm, &version)
@ -163,7 +191,7 @@ func GetGPUInfo() GpuInfo {
if resp.Library == "" { if resp.Library == "" {
C.cpu_check_ram(&memInfo) C.cpu_check_ram(&memInfo)
resp.Library = "cpu" resp.Library = "cpu"
resp.Variant = GetCPUVariant() resp.Variant = cpuVariant
} }
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
@ -199,7 +227,9 @@ func CheckVRAM() (int64, error) {
if overhead < gpus*1024*1024*1024 { if overhead < gpus*1024*1024*1024 {
overhead = gpus * 1024 * 1024 * 1024 overhead = gpus * 1024 * 1024 * 1024
} }
return int64(gpuInfo.FreeMemory - overhead), nil avail := int64(gpuInfo.FreeMemory - overhead)
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
return avail, nil
} }
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation

View file

@ -42,6 +42,7 @@ typedef struct mem_info {
uint64_t total; uint64_t total;
uint64_t free; uint64_t free;
unsigned int count; unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
char *err; // If non-nill, caller responsible for freeing char *err; // If non-nill, caller responsible for freeing
} mem_info_t; } mem_info_t;

View file

@ -70,6 +70,7 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
resp->ch.handle = NULL; resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret); snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
return;
} }
// Report driver version if we're in verbose mode, ignore errors // Report driver version if we're in verbose mode, ignore errors

View file

@ -77,6 +77,7 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) { void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
resp->err = NULL; resp->err = NULL;
resp->igpu_index = -1;
uint64_t totalMem = 0; uint64_t totalMem = 0;
uint64_t usedMem = 0; uint64_t usedMem = 0;
rsmi_status_t ret; rsmi_status_t ret;
@ -162,16 +163,22 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
} }
LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem); LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem); LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
if (totalMem < 1024 * 1024 * 1024) {
// Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
resp->igpu_index = i;
} else {
resp->total += totalMem; resp->total += totalMem;
resp->free += totalMem - usedMem; resp->free += totalMem - usedMem;
} }
} }
}
void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) { void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
const int buflen = 256; const int buflen = 256;
char buf[buflen + 1]; char buf[buflen + 1];
if (h.handle == NULL) { if (h.handle == NULL) {
resp->str = strdup("nvml handle not initialized"); resp->str = strdup("rocm handle not initialized");
resp->status = 1; resp->status = 1;
return; return;
} }

View file

@ -190,6 +190,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
"seed": predict.Options.Seed, "seed": predict.Options.Seed,
"stop": predict.Options.Stop, "stop": predict.Options.Stop,
"image_data": imageData, "image_data": imageData,
"cache_prompt": true,
} }
if predict.Format == "json" { if predict.Format == "json" {

View file

@ -39,6 +39,9 @@ init_vars() {
*) *)
;; ;;
esac esac
if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then
CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
fi
} }
git_module_setup() { git_module_setup() {
@ -61,6 +64,17 @@ apply_patches() {
if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
fi fi
# apply temporary patches until fix is upstream
for patch in ../patches/*.diff; do
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
(cd ${LLAMACPP_DIR}; git checkout ${file})
done
done
for patch in ../patches/*.diff; do
(cd ${LLAMACPP_DIR} && git apply ${patch})
done
# Avoid duplicate main symbols when we link into the cgo binary # Avoid duplicate main symbols when we link into the cgo binary
sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp && sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp

View file

@ -140,7 +140,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
if [ -n "${CUDA_MAJOR}" ]; then if [ -n "${CUDA_MAJOR}" ]; then
CUDA_VARIANT=_v${CUDA_MAJOR} CUDA_VARIANT=_v${CUDA_MAJOR}
fi fi
CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}" CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}" BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda" EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
build build

View file

@ -25,6 +25,11 @@ function init_vars {
} }
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
} else {
$script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
}
} }
function git_module_setup { function git_module_setup {
@ -40,6 +45,29 @@ function apply_patches {
if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) { if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama' Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
} }
# Apply temporary patches until fix is upstream
$patches = Get-ChildItem "../patches/*.diff"
foreach ($patch in $patches) {
# Extract file paths from the patch file
$filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
$parts = $_ -split ' '
($parts[1] -split '/', 2)[1]
}
# Checkout each file
foreach ($file in $filePaths) {
Set-Location -Path ${script:llamacppDir}
git checkout $file
}
}
# Apply each patch
foreach ($patch in $patches) {
Set-Location -Path ${script:llamacppDir}
git apply $patch.FullName
}
# Avoid duplicate main symbols when we link into the cgo binary # Avoid duplicate main symbols when we link into the cgo binary
$content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp" $content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
$content = $content -replace 'int main\(', 'int __main(' $content = $content -replace 'int main\(', 'int __main('
@ -128,7 +156,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
} }
init_vars init_vars
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT" $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on") $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build build
install install
cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib" cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"

View file

@ -69,12 +69,65 @@ type tensor struct {
name string name string
kind uint32 kind uint32
offset uint64 offset uint64
size uint64
// shape is the number of elements in each dimension // shape is the number of elements in each dimension
shape [4]uint64 shape [4]uint64
} }
func (t tensor) blockSize() uint64 {
switch {
case t.kind < 2:
return 1
case t.kind < 10:
return 32
default:
return 256
}
}
func (t tensor) typeSize() uint64 {
blockSize := t.blockSize()
switch t.kind {
case 0: // FP32
return 4
case 1: // FP16
return 2
case 2: // Q4_0
return 2 + blockSize/2
case 3: // Q4_1
return 2 + 2 + blockSize/2
case 6: // Q5_0
return 2 + 4 + blockSize/2
case 7: // Q5_1
return 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
return 2 + blockSize
case 9: // Q8_1
return 4 + 4 + blockSize
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
return blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
return 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
return 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
default:
return 0
}
}
func (t tensor) parameters() uint64 {
return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
}
func (t tensor) size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
}
type ggufModel struct { type ggufModel struct {
*containerGGUF *containerGGUF
@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
shape[i] = llm.readU64(rso) shape[i] = llm.readU64(rso)
} }
kind := llm.readU32(rso) tensor := tensor{
offset := llm.readU64(rso)
var blockSize uint64
switch {
case kind < 2:
blockSize = 1
case kind < 10:
blockSize = 32
default:
blockSize = 256
}
var typeSize uint64
switch kind {
case 0: // FP32
typeSize = 4
case 1: // FP16
typeSize = 2
case 2: // Q4_0
typeSize = 2 + blockSize/2
case 3: // Q4_1
typeSize = 2 + 2 + blockSize/2
case 6: // Q5_0
typeSize = 2 + 4 + blockSize/2
case 7: // Q5_1
typeSize = 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
typeSize = 2 + blockSize
case 9: // Q8_1
typeSize = 4 + 4 + blockSize
case 10: // Q2_K
typeSize = blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
typeSize = blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
typeSize = 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
}
parameters := shape[0] * shape[1] * shape[2] * shape[3]
size := parameters * typeSize / blockSize
llm.tensors = append(llm.tensors, tensor{
name: name, name: name,
kind: kind, kind: llm.readU32(rso),
offset: offset, offset: llm.readU64(rso),
size: size,
shape: shape, shape: shape,
}) }
llm.parameters += parameters llm.tensors = append(llm.tensors, tensor)
llm.parameters += tensor.parameters()
} }
alignment, ok := llm.kv["general.alignment"].(uint32) alignment, ok := llm.kv["general.alignment"].(uint32)
@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent) rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
for _, tensor := range llm.tensors { for _, tensor := range llm.tensors {
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1) padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
rso.Seek(padded, io.SeekCurrent) rso.Seek(padded, io.SeekCurrent)
} }

@ -1 +1 @@
Subproject commit 011e8ec577fd135cbc02993d3ea9840c516d6a1c Subproject commit cd4fddb29f81d6a1f6d51a0c016bc6b486d68def

30
llm/patches/01-cache.diff Normal file
View file

@ -0,0 +1,30 @@
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0462fbd2..4fa7b57f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1857,12 +1857,6 @@ struct llama_server_context
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
}
- LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
-
- llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
-
- slot.cache_tokens = prompt_tokens;
-
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
{
// we have to evaluate at least 1 token to generate logits.
@@ -1870,6 +1864,12 @@ struct llama_server_context
slot.n_past--;
}
+ LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
+
+ llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
+
+ slot.cache_tokens = prompt_tokens;
+
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"slices"
) )
type Command struct { type Command struct {
@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(bytes.TrimSpace(fields[1])) command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED": case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default: default:
if !bytes.HasPrefix(fields[0], []byte("#")) { if !bytes.HasPrefix(fields[0], []byte("#")) {
// log a warning for unknown commands // log a warning for unknown commands

View file

@ -61,3 +61,38 @@ PARAMETER param1
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorContains(t, err, "missing value for [param1]")
} }
func Test_Parser_Messages(t *testing.T) {
input := `
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
}
assert.Equal(t, expectedCommands, commands)
}
func Test_Parser_Messages_BadRole(t *testing.T) {
input := `
FROM foo
MESSAGE badguy I'm a bad guy!
`
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
}

View file

@ -13,3 +13,13 @@ docker build \
-f Dockerfile \ -f Dockerfile \
-t ollama/ollama:$VERSION \ -t ollama/ollama:$VERSION \
. .
docker build \
--load \
--platform=linux/amd64 \
--build-arg=VERSION \
--build-arg=GOFLAGS \
--target runtime-rocm \
-f Dockerfile \
-t ollama/ollama:$VERSION-rocm \
.

View file

@ -25,6 +25,11 @@ import (
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
) )
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
var errPartStalled = errors.New("part stalled")
var blobDownloadManager sync.Map var blobDownloadManager sync.Map
type blobDownload struct { type blobDownload struct {
@ -48,6 +53,7 @@ type blobDownloadPart struct {
Offset int64 Offset int64
Size int64 Size int64
Completed int64 Completed int64
lastUpdated time.Time
*blobDownload `json:"-"` *blobDownload `json:"-"`
} }
@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size return p.Offset + p.Size
} }
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
p.lastUpdated = time.Now()
return n, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*") partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil { if err != nil {
@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space // return immediately if the context is canceled or the device is out of space
return err return err
case errors.Is(err, errPartStalled):
try--
continue
case err != nil: case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)) slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
@ -195,6 +211,8 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
} }
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
@ -203,7 +221,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
} }
defer resp.Body.Close() defer resp.Body.Close()
n, err := io.Copy(w, io.TeeReader(resp.Body, b)) n, err := io.Copy(w, io.TeeReader(resp.Body, part))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress // rollback progress
b.Completed.Add(-n) b.Completed.Add(-n)
@ -217,6 +235,30 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
// return nil or context.Canceled or UnexpectedEOF (resumable) // return nil or context.Canceled or UnexpectedEOF (resumable)
return err return err
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
for {
select {
case <-ticker.C:
if part.Completed >= part.Size {
return nil
}
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
// reset last updated
part.lastUpdated = time.Time{}
return errPartStalled
}
case <-ctx.Done():
return ctx.Err()
}
}
})
return g.Wait()
} }
func (b *blobDownload) newPart(offset, size int64) error { func (b *blobDownload) newPart(offset, size int64) error {
@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
return json.NewEncoder(partFile).Encode(part) return json.NewEncoder(partFile).Encode(part)
} }
func (b *blobDownload) Write(p []byte) (n int, err error) {
n = len(p)
b.Completed.Add(int64(n))
return n, nil
}
func (b *blobDownload) acquire() { func (b *blobDownload) acquire() {
b.references.Add(1) b.references.Add(1)
} }
@ -279,10 +315,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
case <-ctx.Done():
return ctx.Err()
}
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]), Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest, Digest: b.Digest,
@ -293,6 +325,9 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
if b.done || b.err != nil { if b.done || b.err != nil {
return b.err return b.err
} }
case <-ctx.Done():
return ctx.Err()
}
} }
} }
@ -303,10 +338,6 @@ type downloadOpts struct {
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// downloadBlob downloads a blob from the registry and stores it in the blobs directory // downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error { func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest) fp, err := GetBlobsPath(opts.digest)

View file

@ -41,7 +41,7 @@ type Model struct {
Config ConfigV2 Config ConfigV2
ShortName string ShortName string
ModelPath string ModelPath string
OriginalModel string ParentModel string
AdapterPaths []string AdapterPaths []string
ProjectorPaths []string ProjectorPaths []string
Template string Template string
@ -50,6 +50,12 @@ type Model struct {
Digest string Digest string
Size int64 Size int64
Options map[string]interface{} Options map[string]interface{}
Messages []Message
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
} }
type PromptVars struct { type PromptVars struct {
@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename model.ModelPath = filename
model.OriginalModel = layer.From model.ParentModel = layer.From
case "application/vnd.ollama.image.embed": case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2 // Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version // TODO: remove this warning in a future version
@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil { if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.messages":
msgs, err := os.Open(filename)
if err != nil {
return nil, err
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license": case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
} }
var layers Layers var layers Layers
messages := []string{}
params := make(map[string][]string) params := make(map[string][]string)
fromParams := make(map[string]any) fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args))
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
} }
layers.Replace(layer) layers.Replace(layer)
case "message":
messages = append(messages, c.Args)
default: default:
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
if len(messages) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
msgs := make([]api.Message, 0)
for _, m := range messages {
// todo: handle images
msg := strings.SplitN(m, ": ", 2)
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
return err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return err
}
layers.Replace(layer)
}
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) {
mt.Model = model mt.Model = model
mt.From = model.ModelPath mt.From = model.ModelPath
if model.OriginalModel != "" { if model.ParentModel != "" {
mt.From = model.OriginalModel mt.From = model.ParentModel
} }
modelFile := `# Modelfile generated by "ollama show" modelFile := `# Modelfile generated by "ollama show"

View file

@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
return return
} }
sessionDuration := defaultSessionDuration var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -659,6 +672,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
modelDetails := api.ModelDetails{ modelDetails := api.ModelDetails{
ParentModel: model.ParentModel,
Format: model.Config.ModelFormat, Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily, Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies, Families: model.Config.ModelFamilies,
@ -674,11 +688,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model.Template = req.Template model.Template = req.Template
} }
msgs := make([]api.Message, 0)
for _, msg := range model.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
resp := &api.ShowResponse{ resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"), License: strings.Join(model.License, "\n"),
System: model.System, System: model.System,
Template: model.Template, Template: model.Template,
Details: modelDetails, Details: modelDetails,
Messages: msgs,
} }
var params []string var params []string
@ -1067,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@ -1075,7 +1102,13 @@ func ChatHandler(c *gin.Context) {
// an empty request loads the model // an empty request loads the model
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}}) resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return return
} }