diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index ad178cab..f5174c33 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -23,29 +23,72 @@ jobs: with: go-version: '1.21' cache: true - - if: ${{ startsWith(matrix.os, 'windows-') }} - shell: pwsh - run: | - $path = vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath - if ($path) { - $path = join-path $path 'Common7\Tools\vsdevcmd.bat' - if (test-path $path) { - cmd /s /c """$path"" $args && set" | where { $_ -match '(\w+)=(.*)' } | foreach { - echo "$($Matches[1])=$($Matches[2])" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append - } - } - } - - echo "C:\Program Files\Git\usr\bin" | Out-File -FilePath $Env:GITHUB_PATH -Encoding utf8 -Append - run: go get ./... - run: go generate -x ./... - uses: actions/upload-artifact@v4 with: name: ${{ matrix.os }}-${{ matrix.arch }}-libraries - path: | - llm/llama.cpp/build/**/lib/* + path: llm/llama.cpp/build/**/lib/* + generate-cuda: + strategy: + matrix: + cuda-version: + - '11.8.0' + runs-on: ubuntu-latest + container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04 + steps: + - run: | + apt-get update && apt-get install -y git build-essential curl + curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \ + | tar -zx -C /usr --strip-components 1 + env: + DEBIAN_FRONTEND: noninteractive + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version: '1.21' + cache: true + - run: go get ./... + - run: | + git config --global --add safe.directory /__w/ollama/ollama + go generate -x ./... + env: + OLLAMA_SKIP_CPU_GENERATE: '1' + - uses: actions/upload-artifact@v4 + with: + name: cuda-${{ matrix.cuda-version }}-libraries + path: llm/llama.cpp/build/**/lib/* + generate-rocm: + strategy: + matrix: + rocm-version: + - '5.7.1' + - '6.0' + runs-on: ubuntu-latest + container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} + steps: + - run: | + apt-get update && apt-get install -y git build-essential curl rocm-libs + curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \ + | tar -zx -C /usr --strip-components 1 + env: + DEBIAN_FRONTEND: noninteractive + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version: '1.21' + cache: true + - run: go get ./... + - run: | + git config --global --add safe.directory /__w/ollama/ollama + go generate -x ./... 
+ env: + OLLAMA_SKIP_CPU_GENERATE: '1' + - uses: actions/upload-artifact@v4 + with: + name: rocm-${{ matrix.rocm-version }}-libraries + path: llm/llama.cpp/build/**/lib/* lint: - needs: generate strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] @@ -69,10 +112,19 @@ jobs: with: go-version: '1.21' cache: false - - uses: actions/download-artifact@v4 - with: - name: ${{ matrix.os }}-${{ matrix.arch }}-libraries - path: llm/llama.cpp/build + - run: | + mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/ + touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so + if: ${{ startsWith(matrix.os, 'ubuntu-') }} + - run: | + mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/ + touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib + touch llm/llama.cpp/ggml-metal.metal + if: ${{ startsWith(matrix.os, 'macos-') }} + - run: | + mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/ + touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll + if: ${{ startsWith(matrix.os, 'windows-') }} - uses: golangci/golangci-lint-action@v3 test: needs: generate @@ -104,3 +156,7 @@ jobs: path: llm/llama.cpp/build - run: go build - run: go test -v ./... + - uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.os }}-binaries + path: ollama diff --git a/Dockerfile b/Dockerfile index 9767faa3..7c921df8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -109,17 +109,28 @@ ARG CGO_CFLAGS RUN go build . # Runtime stages -FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete as runtime-amd64 +FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64 +RUN apt-get update && apt-get install -y ca-certificates COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64 RUN apt-get update && apt-get install -y ca-certificates COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama +# Radeon images are much larger so we keep it distinct from the CPU/CUDA image +FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm +RUN update-pciids +COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama +EXPOSE 11434 +ENV OLLAMA_HOST 0.0.0.0 + +ENTRYPOINT ["/bin/ollama"] +CMD ["serve"] + FROM runtime-$TARGETARCH EXPOSE 11434 ENV OLLAMA_HOST 0.0.0.0 ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin -ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/opt/rocm/lib: +ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility ENTRYPOINT ["/bin/ollama"] diff --git a/api/types.go b/api/types.go index d4e385bf..609c4a8a 100644 --- a/api/types.go +++ b/api/types.go @@ -34,24 +34,26 @@ func (e StatusError) Error() string { type ImageData []byte type GenerateRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - System string `json:"system"` - Template string `json:"template"` - Context []int `json:"context,omitempty"` - Stream *bool `json:"stream,omitempty"` - Raw bool `json:"raw,omitempty"` - Format string `json:"format"` - Images []ImageData `json:"images,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt"` + System string `json:"system"` + Template string `json:"template"` + Context []int `json:"context,omitempty"` + Stream *bool `json:"stream,omitempty"` + Raw bool `json:"raw,omitempty"` + Format string `json:"format"` + KeepAlive *Duration `json:"keep_alive,omitempty"` + 
Images []ImageData `json:"images,omitempty"` Options map[string]interface{} `json:"options"` } type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Stream *bool `json:"stream,omitempty"` - Format string `json:"format"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream *bool `json:"stream,omitempty"` + Format string `json:"format"` + KeepAlive *Duration `json:"keep_alive,omitempty"` Options map[string]interface{} `json:"options"` } @@ -126,8 +128,9 @@ type Runner struct { } type EmbeddingRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` + Model string `json:"model"` + Prompt string `json:"prompt"` + KeepAlive *Duration `json:"keep_alive,omitempty"` Options map[string]interface{} `json:"options"` } @@ -171,6 +174,7 @@ type ShowResponse struct { Template string `json:"template,omitempty"` System string `json:"system,omitempty"` Details ModelDetails `json:"details,omitempty"` + Messages []Message `json:"messages,omitempty"` } type CopyRequest struct { @@ -236,6 +240,7 @@ type GenerateResponse struct { } type ModelDetails struct { + ParentModel string `json:"parent_model"` Format string `json:"format"` Family string `json:"family"` Families []string `json:"families"` @@ -411,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { case float64: if t < 0 { t = math.MaxFloat64 + d.Duration = time.Duration(t) + } else { + d.Duration = time.Duration(t * float64(time.Second)) } - - d.Duration = time.Duration(t) case string: d.Duration, err = time.ParseDuration(t) if err != nil { return err } + if d.Duration < 0 { + mf := math.MaxFloat64 + d.Duration = time.Duration(mf) + } } return nil diff --git a/cmd/cmd.go b/cmd/cmd.go index 76e3c7a9..915fa993 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -458,15 +458,17 @@ func RunGenerate(cmd *cobra.Command, args []string) error { type generateContextKey string type runOptions struct { - Model string - Prompt string - Messages []api.Message - WordWrap bool - Format string - System string - Template string - Images []api.ImageData - Options map[string]interface{} + Model string + ParentModel string + Prompt string + Messages []api.Message + WordWrap bool + Format string + System string + Template string + Images []api.ImageData + Options map[string]interface{} + MultiModal bool } type displayResponseState struct { diff --git a/cmd/interactive.go b/cmd/interactive.go index da3c5b72..d337e555 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -7,12 +7,14 @@ import ( "net/http" "os" "regexp" + "sort" "strings" "github.com/spf13/cobra" "golang.org/x/exp/slices" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/progress" "github.com/jmorganca/ollama/readline" ) @@ -25,43 +27,75 @@ const ( MultilineTemplate ) -func modelIsMultiModal(cmd *cobra.Command, name string) bool { - // get model details +func loadModel(cmd *cobra.Command, opts *runOptions) error { client, err := api.ClientFromEnvironment() if err != nil { - fmt.Println("error: couldn't connect to ollama server") - return false + return err } - req := api.ShowRequest{Name: name} - resp, err := client.Show(cmd.Context(), &req) + p := progress.NewProgress(os.Stderr) + defer p.StopAndClear() + + spinner := progress.NewSpinner("") + p.Add("", spinner) + + showReq := api.ShowRequest{Name: opts.Model} + showResp, err := client.Show(cmd.Context(), &showReq) if err != nil { - return false + return err + } + opts.MultiModal = slices.Contains(showResp.Details.Families, "clip") + 
opts.ParentModel = showResp.Details.ParentModel + + if len(showResp.Messages) > 0 { + opts.Messages = append(opts.Messages, showResp.Messages...) } - return slices.Contains(resp.Details.Families, "clip") + chatReq := &api.ChatRequest{ + Model: opts.Model, + Messages: []api.Message{}, + } + err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error { + p.StopAndClear() + if len(opts.Messages) > 0 { + for _, msg := range opts.Messages { + switch msg.Role { + case "user": + fmt.Printf(">>> %s\n", msg.Content) + case "assistant": + state := &displayResponseState{} + displayResponse(msg.Content, opts.WordWrap, state) + fmt.Println() + fmt.Println() + } + } + } + return nil + }) + if err != nil { + return err + } + + return nil } func generateInteractive(cmd *cobra.Command, opts runOptions) error { - multiModal := modelIsMultiModal(cmd, opts.Model) + opts.Messages = make([]api.Message, 0) - // load the model - loadOpts := runOptions{ - Model: opts.Model, - Prompt: "", - Messages: []api.Message{}, - } - if _, err := chat(cmd, loadOpts); err != nil { + err := loadModel(cmd, &opts) + if err != nil { return err } usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") - fmt.Fprintln(os.Stderr, " /set Set session variables") - fmt.Fprintln(os.Stderr, " /show Show model information") - fmt.Fprintln(os.Stderr, " /bye Exit") - fmt.Fprintln(os.Stderr, " /?, /help Help for a command") - fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") + fmt.Fprintln(os.Stderr, " /set Set session variables") + fmt.Fprintln(os.Stderr, " /show Show model information") + fmt.Fprintln(os.Stderr, " /load Load a session or model") + fmt.Fprintln(os.Stderr, " /save Save your current session") + fmt.Fprintln(os.Stderr, " /bye Exit") + fmt.Fprintln(os.Stderr, " /?, /help Help for a command") + fmt.Fprintln(os.Stderr, " /? 
shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "") @@ -140,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { var sb strings.Builder var multiline MultilineState - opts.Messages = make([]api.Message, 0) for { line, err := scanner.Readline() @@ -203,6 +236,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { if err := ListHandler(cmd, args[1:]); err != nil { return err } + case strings.HasPrefix(line, "/load"): + args := strings.Fields(line) + if len(args) != 2 { + fmt.Println("Usage:\n /load ") + continue + } + opts.Model = args[1] + opts.Messages = []api.Message{} + fmt.Printf("Loading model '%s'\n", opts.Model) + if err := loadModel(cmd, &opts); err != nil { + return err + } + continue + case strings.HasPrefix(line, "/save"): + args := strings.Fields(line) + if len(args) != 2 { + fmt.Println("Usage:\n /save ") + continue + } + + client, err := api.ClientFromEnvironment() + if err != nil { + fmt.Println("error: couldn't connect to ollama server") + return err + } + + req := &api.CreateRequest{ + Name: args[1], + Modelfile: buildModelfile(opts), + } + fn := func(resp api.ProgressResponse) error { return nil } + err = client.Create(cmd.Context(), req, fn) + if err != nil { + fmt.Println("error: couldn't save model") + return err + } + fmt.Printf("Created new model '%s'\n", args[1]) + continue case strings.HasPrefix(line, "/set"): args := strings.Fields(line) if len(args) > 1 { @@ -389,7 +460,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { args := strings.Fields(line) isFile := false - if multiModal { + if opts.MultiModal { for _, f := range extractFileNames(line) { if strings.HasPrefix(f, args[0]) { isFile = true @@ -411,7 +482,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { if sb.Len() > 0 && multiline == MultilineNone { newMessage := api.Message{Role: "user", Content: sb.String()} - if multiModal { + if opts.MultiModal { msg, images, err := extractFileData(sb.String()) if err != nil { return err @@ -454,6 +525,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } } +func buildModelfile(opts runOptions) string { + var mf strings.Builder + model := opts.ParentModel + if model == "" { + model = opts.Model + } + fmt.Fprintf(&mf, "FROM %s\n", model) + if opts.System != "" { + fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System) + } + + if opts.Template != "" { + fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template) + } + + keys := make([]string, 0) + for k := range opts.Options { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k]) + } + fmt.Fprintln(&mf) + + for _, msg := range opts.Messages { + fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content) + } + + return mf.String() +} + func normalizeFilePath(fp string) string { // Define a map of escaped characters and their replacements replacements := map[string]string{ diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index 1bd5058a..19e43287 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -1,9 +1,13 @@ package cmd import ( + "bytes" "testing" + "text/template" "github.com/stretchr/testify/assert" + + "github.com/jmorganca/ollama/api" ) func TestExtractFilenames(t *testing.T) { @@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png 
inbetween8 assert.Contains(t, res[9], "ten.svg") assert.Contains(t, res[9], "E:") } + +func TestModelfileBuilder(t *testing.T) { + opts := runOptions{ + Model: "hork", + System: "You are part horse and part shark, but all hork. Do horklike things", + Template: "This is a template.", + Messages: []api.Message{ + {Role: "user", Content: "Hey there hork!"}, + {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, + }, + Options: map[string]interface{}{}, + } + + opts.Options["temperature"] = 0.9 + opts.Options["seed"] = 42 + opts.Options["penalize_newline"] = false + opts.Options["stop"] = []string{"hi", "there"} + + mf := buildModelfile(opts) + expectedModelfile := `FROM {{.Model}} +SYSTEM """{{.System}}""" +TEMPLATE """{{.Template}}""" +PARAMETER penalize_newline false +PARAMETER seed 42 +PARAMETER stop [hi there] +PARAMETER temperature 0.9 + +MESSAGE user """Hey there hork!""" +MESSAGE assistant """Yes it is true, I am half horse, half shark.""" +` + + tmpl, err := template.New("").Parse(expectedModelfile) + assert.Nil(t, err) + + var buf bytes.Buffer + err = tmpl.Execute(&buf, opts) + assert.Nil(t, err) + assert.Equal(t, buf.String(), mf) + + opts.ParentModel = "horseshark" + mf = buildModelfile(opts) + expectedModelfile = `FROM {{.ParentModel}} +SYSTEM """{{.System}}""" +TEMPLATE """{{.Template}}""" +PARAMETER penalize_newline false +PARAMETER seed 42 +PARAMETER stop [hi there] +PARAMETER temperature 0.9 + +MESSAGE user """Hey there hork!""" +MESSAGE assistant """Yes it is true, I am half horse, half shark.""" +` + + tmpl, err = template.New("").Parse(expectedModelfile) + assert.Nil(t, err) + + var parentBuf bytes.Buffer + err = tmpl.Execute(&parentBuf, opts) + assert.Nil(t, err) + assert.Equal(t, parentBuf.String(), mf) +} diff --git a/docs/development.md b/docs/development.md index ac45a3e0..59651b1f 100644 --- a/docs/development.md +++ b/docs/development.md @@ -50,7 +50,8 @@ development and runtime packages. Typically the build scripts will auto-detect CUDA, however, if your Linux distro or installation approach uses unusual paths, you can specify the location by specifying an environment variable `CUDA_LIB_DIR` to the location of the shared -libraries, and `CUDACXX` to the location of the nvcc compiler. +libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize +the set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70"). Then generate dependencies: diff --git a/docs/modelfile.md b/docs/modelfile.md index 6134bf9c..6d6ac152 100644 --- a/docs/modelfile.md +++ b/docs/modelfile.md @@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama. - [SYSTEM](#system) - [ADAPTER](#adapter) - [LICENSE](#license) + - [MESSAGE](#message) - [Notes](#notes) ## Format @@ -38,6 +39,7 @@ INSTRUCTION arguments | [`SYSTEM`](#system) | Specifies the system message that will be set in the template. | | [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. | | [`LICENSE`](#license) | Specifies the legal license. | +| [`MESSAGE`](#message) | Specifies the message history. | ## Examples @@ -205,6 +207,19 @@ LICENSE """ """ ``` +### MESSAGE + +The `MESSAGE` instruction allows you to specify a message history for the model to use when responding: + +```modelfile +MESSAGE user Is Toronto in Canada? +MESSAGE assistant yes +MESSAGE user Is Sacramento in Canada? +MESSAGE assistant no +MESSAGE user Is Ontario in Canada?
+MESSAGE assistant yes +``` + ## Notes - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments. diff --git a/gpu/gpu.go b/gpu/gpu.go index fb120ea5..550467a3 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -16,6 +16,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "sync" "unsafe" @@ -29,8 +30,8 @@ type handles struct { var gpuMutex sync.Mutex var gpuHandles *handles = nil -// With our current CUDA compile flags, 5.2 and older will not work properly -const CudaComputeMajorMin = 6 +// With our current CUDA compile flags, older than 5.0 will not work properly +var CudaComputeMin = [2]C.int{5, 0} // Possible locations for the nvidia-ml library var CudaLinuxGlobs = []string{ @@ -121,9 +122,15 @@ func GetGPUInfo() GpuInfo { initGPUHandles() } + // All our GPU builds have AVX enabled, so fallback to CPU if we don't detect at least AVX + cpuVariant := GetCPUVariant() + if cpuVariant == "" { + slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.") + } + var memInfo C.mem_info_t resp := GpuInfo{} - if gpuHandles.cuda != nil { + if gpuHandles.cuda != nil && cpuVariant != "" { C.cuda_check_vram(*gpuHandles.cuda, &memInfo) if memInfo.err != nil { slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))) @@ -135,19 +142,40 @@ func GetGPUInfo() GpuInfo { if cc.err != nil { slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err))) C.free(unsafe.Pointer(cc.err)) - } else if cc.major >= CudaComputeMajorMin { + } else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) { slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)) resp.Library = "cuda" } else { slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. 
Compute Capability detected: %d.%d", cc.major, cc.minor)) } } - } else if gpuHandles.rocm != nil { + } else if gpuHandles.rocm != nil && cpuVariant != "" { C.rocm_check_vram(*gpuHandles.rocm, &memInfo) if memInfo.err != nil { slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))) C.free(unsafe.Pointer(memInfo.err)) + } else if memInfo.igpu_index >= 0 && memInfo.count == 1 { + // Only one GPU detected and it appears to be an integrated GPU - skip it + slog.Info("ROCm unsupported integrated GPU detected") } else { + if memInfo.igpu_index >= 0 { + // We have multiple GPUs reported, and one of them is an integrated GPU + // so we have to set the env var to bypass it + // If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it + val := os.Getenv("ROCR_VISIBLE_DEVICES") + if val == "" { + devices := []string{} + for i := 0; i < int(memInfo.count); i++ { + if i == int(memInfo.igpu_index) { + continue + } + devices = append(devices, strconv.Itoa(i)) + } + val = strings.Join(devices, ",") + os.Setenv("ROCR_VISIBLE_DEVICES", val) + } + slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val)) + } resp.Library = "rocm" var version C.rocm_version_resp_t C.rocm_get_version(*gpuHandles.rocm, &version) @@ -163,7 +191,7 @@ func GetGPUInfo() GpuInfo { if resp.Library == "" { C.cpu_check_ram(&memInfo) resp.Library = "cpu" - resp.Variant = GetCPUVariant() + resp.Variant = cpuVariant } if memInfo.err != nil { slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err))) @@ -199,7 +227,9 @@ func CheckVRAM() (int64, error) { if overhead < gpus*1024*1024*1024 { overhead = gpus * 1024 * 1024 * 1024 } - return int64(gpuInfo.FreeMemory - overhead), nil + avail := int64(gpuInfo.FreeMemory - overhead) + slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024)) + return avail, nil } return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation diff --git a/gpu/gpu_info.h b/gpu/gpu_info.h index f32efa8e..e52d2066 100644 --- a/gpu/gpu_info.h +++ b/gpu/gpu_info.h @@ -42,6 +42,7 @@ typedef struct mem_info { uint64_t total; uint64_t free; unsigned int count; + int igpu_index; // If >= 0, we detected an integrated GPU to ignore char *err; // If non-nill, caller responsible for freeing } mem_info_t; diff --git a/gpu/gpu_info_cuda.c b/gpu/gpu_info_cuda.c index 9299b22c..d877ff0c 100644 --- a/gpu/gpu_info_cuda.c +++ b/gpu/gpu_info_cuda.c @@ -70,6 +70,7 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) { resp->ch.handle = NULL; snprintf(buf, buflen, "nvml vram init failure: %d", ret); resp->err = strdup(buf); + return; } // Report driver version if we're in verbose mode, ignore errors diff --git a/gpu/gpu_info_rocm.c b/gpu/gpu_info_rocm.c index 59ab0817..7ac88611 100644 --- a/gpu/gpu_info_rocm.c +++ b/gpu/gpu_info_rocm.c @@ -77,6 +77,7 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) { void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) { resp->err = NULL; + resp->igpu_index = -1; uint64_t totalMem = 0; uint64_t usedMem = 0; rsmi_status_t ret; @@ -162,8 +163,14 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) { } LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem); LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem); - resp->total += totalMem; - resp->free += totalMem - usedMem; + if (totalMem < 1024 * 1024 * 1024) { + // Do not add up integrated GPU memory capacity, it's 
a bogus 512M, and actually uses system memory + LOG(h.verbose, "[%d] ROCm integrated GPU\n", i); + resp->igpu_index = i; + } else { + resp->total += totalMem; + resp->free += totalMem - usedMem; + } } } @@ -171,7 +178,7 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) { const int buflen = 256; char buf[buflen + 1]; if (h.handle == NULL) { - resp->str = strdup("nvml handle not initialized"); + resp->str = strdup("rocm handle not initialized"); resp->status = 1; return; } @@ -188,4 +195,4 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) { resp->str = strdup(buf); } -#endif // __APPLE__ \ No newline at end of file +#endif // __APPLE__ diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go index 45e2dc72..8674a514 100644 --- a/llm/dyn_ext_server.go +++ b/llm/dyn_ext_server.go @@ -190,6 +190,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu "seed": predict.Options.Seed, "stop": predict.Options.Stop, "image_data": imageData, + "cache_prompt": true, } if predict.Format == "json" { diff --git a/llm/generate/gen_common.sh b/llm/generate/gen_common.sh index d1e64d7d..43d3dce5 100644 --- a/llm/generate/gen_common.sh +++ b/llm/generate/gen_common.sh @@ -39,6 +39,9 @@ init_vars() { *) ;; esac + if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then + CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80" + fi } git_module_setup() { @@ -61,6 +64,17 @@ apply_patches() { if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt fi + + # apply temporary patches until fix is upstream + for patch in ../patches/*.diff; do + for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do + (cd ${LLAMACPP_DIR}; git checkout ${file}) + done + done + for patch in ../patches/*.diff; do + (cd ${LLAMACPP_DIR} && git apply ${patch}) + done + # Avoid duplicate main symbols when we link into the cgo binary sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp && mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp diff --git a/llm/generate/gen_linux.sh b/llm/generate/gen_linux.sh index b5190cfa..65ca602e 100755 --- a/llm/generate/gen_linux.sh +++ b/llm/generate/gen_linux.sh @@ -140,7 +140,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then if [ -n "${CUDA_MAJOR}" ]; then CUDA_VARIANT=_v${CUDA_MAJOR} fi - CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}" + CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}" BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}" EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda" build diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1 index 109b8602..f7a241cc 100644 --- a/llm/generate/gen_windows.ps1 +++ b/llm/generate/gen_windows.ps1 @@ -25,6 +25,11 @@ function init_vars { } $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path + if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) { + $script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80" + } else { + $script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES + } } function git_module_setup { @@ -40,6 +45,29 @@ function apply_patches { if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 
'ollama')) { Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama' } + + # Apply temporary patches until fix is upstream + $patches = Get-ChildItem "../patches/*.diff" + foreach ($patch in $patches) { + # Extract file paths from the patch file + $filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object { + $parts = $_ -split ' ' + ($parts[1] -split '/', 2)[1] + } + + # Checkout each file + foreach ($file in $filePaths) { + Set-Location -Path ${script:llamacppDir} + git checkout $file + } + } + + # Apply each patch + foreach ($patch in $patches) { + Set-Location -Path ${script:llamacppDir} + git apply $patch.FullName + } + # Avoid duplicate main symbols when we link into the cgo binary $content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp" $content = $content -replace 'int main\(', 'int __main(' @@ -128,7 +156,7 @@ if ($null -ne $script:CUDA_LIB_DIR) { } init_vars $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT" - $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on") + $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}") build install cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib" diff --git a/llm/gguf.go b/llm/gguf.go index cfcab758..436be42c 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -69,12 +69,65 @@ type tensor struct { name string kind uint32 offset uint64 - size uint64 // shape is the number of elements in each dimension shape [4]uint64 } +func (t tensor) blockSize() uint64 { + switch { + case t.kind < 2: + return 1 + case t.kind < 10: + return 32 + default: + return 256 + } +} + +func (t tensor) typeSize() uint64 { + blockSize := t.blockSize() + + switch t.kind { + case 0: // FP32 + return 4 + case 1: // FP16 + return 2 + case 2: // Q4_0 + return 2 + blockSize/2 + case 3: // Q4_1 + return 2 + 2 + blockSize/2 + case 6: // Q5_0 + return 2 + 4 + blockSize/2 + case 7: // Q5_1 + return 2 + 2 + 4 + blockSize/2 + case 8: // Q8_0 + return 2 + blockSize + case 9: // Q8_1 + return 4 + 4 + blockSize + case 10: // Q2_K + return blockSize/16 + blockSize/4 + 2 + 2 + case 11: // Q3_K + return blockSize/8 + blockSize/4 + 12 + 2 + case 12: // Q4_K + return 2 + 2 + 12 + blockSize/2 + case 13: // Q5_K + return 2 + 2 + 12 + blockSize/8 + blockSize/2 + case 14: // Q6_K + return blockSize/2 + blockSize/4 + blockSize/16 + 2 + default: + return 0 + } +} + +func (t tensor) parameters() uint64 { + return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3] +} + +func (t tensor) size() uint64 { + return t.parameters() * t.typeSize() / t.blockSize() +} + type ggufModel struct { *containerGGUF @@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error { shape[i] = llm.readU64(rso) } - kind := llm.readU32(rso) - offset := llm.readU64(rso) - - var blockSize uint64 - switch { - case kind < 2: - blockSize = 1 - case kind < 10: - blockSize = 32 - default: - blockSize = 256 - } - - var typeSize uint64 - switch kind { - case 0: // FP32 - typeSize = 4 - case 1: // FP16 - typeSize = 2 - case 2: // Q4_0 - typeSize = 2 + blockSize/2 - case 3: // Q4_1 - typeSize = 2 + 2 + blockSize/2 - case 6: // Q5_0 - typeSize = 2 + 4 + blockSize/2 - case 7: // Q5_1 - typeSize = 2 + 2 + 4 + blockSize/2 - case 8: // Q8_0 - typeSize = 2 + blockSize - case 9: // Q8_1 - typeSize = 4 + 4 + blockSize - case 10: // Q2_K - typeSize = 
blockSize/16 + blockSize/4 + 2 + 2 - case 11: // Q3_K - typeSize = blockSize/8 + blockSize/4 + 12 + 2 - case 12: // Q4_K - typeSize = 2 + 2 + 12 + blockSize/2 - case 13: // Q5_K - typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2 - case 14: // Q6_K - typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2 - } - - parameters := shape[0] * shape[1] * shape[2] * shape[3] - size := parameters * typeSize / blockSize - - llm.tensors = append(llm.tensors, tensor{ + tensor := tensor{ name: name, - kind: kind, - offset: offset, - size: size, + kind: llm.readU32(rso), + offset: llm.readU64(rso), shape: shape, - }) + } - llm.parameters += parameters + llm.tensors = append(llm.tensors, tensor) + llm.parameters += tensor.parameters() } alignment, ok := llm.kv["general.alignment"].(uint32) @@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error { rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent) for _, tensor := range llm.tensors { - padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1) + padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1) rso.Seek(padded, io.SeekCurrent) } diff --git a/llm/llama.cpp b/llm/llama.cpp index 011e8ec5..cd4fddb2 160000 --- a/llm/llama.cpp +++ b/llm/llama.cpp @@ -1 +1 @@ -Subproject commit 011e8ec577fd135cbc02993d3ea9840c516d6a1c +Subproject commit cd4fddb29f81d6a1f6d51a0c016bc6b486d68def diff --git a/llm/patches/01-cache.diff b/llm/patches/01-cache.diff new file mode 100644 index 00000000..f8392495 --- /dev/null +++ b/llm/patches/01-cache.diff @@ -0,0 +1,30 @@ +diff --git a/examples/server/server.cpp b/examples/server/server.cpp +index 0462fbd2..4fa7b57f 100644 +--- a/examples/server/server.cpp ++++ b/examples/server/server.cpp +@@ -1857,12 +1857,6 @@ struct llama_server_context + LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); + } + +- LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); +- +- llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); +- +- slot.cache_tokens = prompt_tokens; +- + if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0) + { + // we have to evaluate at least 1 token to generate logits. 
+@@ -1870,6 +1864,12 @@ struct llama_server_context + slot.n_past--; + } + ++ LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); ++ ++ llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); ++ ++ slot.cache_tokens = prompt_tokens; ++ + LOG_VERBOSE("prompt ingested", { + {"n_past", slot.n_past}, + {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, diff --git a/parser/parser.go b/parser/parser.go index 2fbd3cc5..947848b2 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log/slog" + "slices" ) type Command struct { @@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) { command.Args = string(bytes.TrimSpace(fields[1])) case "EMBED": return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") + case "MESSAGE": + command.Name = string(bytes.ToLower(fields[0])) + fields = bytes.SplitN(fields[1], []byte(" "), 2) + if len(fields) < 2 { + return nil, fmt.Errorf("should be in the format ") + } + if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) { + return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"") + } + command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1])) default: if !bytes.HasPrefix(fields[0], []byte("#")) { // log a warning for unknown commands diff --git a/parser/parser_test.go b/parser/parser_test.go index 53555ad1..25e849b5 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -61,3 +61,38 @@ PARAMETER param1 assert.ErrorContains(t, err, "missing value for [param1]") } + +func Test_Parser_Messages(t *testing.T) { + + input := ` +FROM foo +MESSAGE system You are a Parser. Always Parse things. +MESSAGE user Hey there! +MESSAGE assistant Hello, I want to parse all the things! +` + + reader := strings.NewReader(input) + commands, err := Parse(reader) + assert.Nil(t, err) + + expectedCommands := []Command{ + {Name: "model", Args: "foo"}, + {Name: "message", Args: "system: You are a Parser. Always Parse things."}, + {Name: "message", Args: "user: Hey there!"}, + {Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, + } + + assert.Equal(t, expectedCommands, commands) +} + +func Test_Parser_Messages_BadRole(t *testing.T) { + + input := ` +FROM foo +MESSAGE badguy I'm a bad guy! +` + + reader := strings.NewReader(input) + _, err := Parse(reader) + assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"") +} diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index ef02a144..40054ca6 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -13,3 +13,13 @@ docker build \ -f Dockerfile \ -t ollama/ollama:$VERSION \ . + +docker build \ + --load \ + --platform=linux/amd64 \ + --build-arg=VERSION \ + --build-arg=GOFLAGS \ + --target runtime-rocm \ + -f Dockerfile \ + -t ollama/ollama:$VERSION-rocm \ + . 
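Reviewer note (not part of the patch): the parser changes above store each `MESSAGE` instruction as a `Command` whose `Args` field is `"role: content"`, which `CreateModel` later splits to build the messages layer. A minimal sketch of exercising that behavior from Go — assuming the module import path `github.com/jmorganca/ollama/parser` used elsewhere in this diff, and an arbitrary base model name:

```go
package main

import (
	"fmt"
	"log"
	"strings"

	"github.com/jmorganca/ollama/parser"
)

func main() {
	modelfile := `FROM llama2
MESSAGE user Is Toronto in Canada?
MESSAGE assistant yes
`

	// Parse yields one Command per instruction; for MESSAGE, Args holds
	// "role: content", as asserted in parser_test.go above.
	commands, err := parser.Parse(strings.NewReader(modelfile))
	if err != nil {
		log.Fatal(err)
	}

	for _, c := range commands {
		if c.Name != "message" {
			continue
		}
		// Split Args back apart the same way CreateModel does when it
		// writes the application/vnd.ollama.image.messages layer.
		role, content, _ := strings.Cut(c.Args, ": ")
		fmt.Printf("%s: %q\n", role, content)
	}
}
```

Invalid roles are rejected at parse time (only "system", "user", and "assistant" are accepted), so a bad Modelfile fails before any layer is written.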
diff --git a/server/download.go b/server/download.go index b5858487..f089bd41 100644 --- a/server/download.go +++ b/server/download.go @@ -25,6 +25,11 @@ import ( "github.com/jmorganca/ollama/format" ) +const maxRetries = 6 + +var errMaxRetriesExceeded = errors.New("max retries exceeded") +var errPartStalled = errors.New("part stalled") + var blobDownloadManager sync.Map type blobDownload struct { @@ -44,10 +49,11 @@ type blobDownload struct { } type blobDownloadPart struct { - N int - Offset int64 - Size int64 - Completed int64 + N int + Offset int64 + Size int64 + Completed int64 + lastUpdated time.Time *blobDownload `json:"-"` } @@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 { return p.Offset + p.Size } +func (p *blobDownloadPart) Write(b []byte) (n int, err error) { + n = len(b) + p.blobDownload.Completed.Add(int64(n)) + p.lastUpdated = time.Now() + return n, nil +} + func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { partFilePaths, err := filepath.Glob(b.Name + "-partial-*") if err != nil { @@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): // return immediately if the context is canceled or the device is out of space return err + case errors.Is(err, errPartStalled): + try-- + continue case err != nil: sleep := time.Second * time.Duration(math.Pow(2, float64(try))) slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)) @@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis } func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { - headers := make(http.Header) - headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) - resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) - if err != nil { - return err - } - defer resp.Body.Close() + g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + headers := make(http.Header) + headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) + resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) + if err != nil { + return err + } + defer resp.Body.Close() - n, err := io.Copy(w, io.TeeReader(resp.Body, b)) - if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { - // rollback progress - b.Completed.Add(-n) - return err - } + n, err := io.Copy(w, io.TeeReader(resp.Body, part)) + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { + // rollback progress + b.Completed.Add(-n) + return err + } - part.Completed += n - if err := b.writePart(part.Name(), part); err != nil { - return err - } + part.Completed += n + if err := b.writePart(part.Name(), part); err != nil { + return err + } - // return nil or context.Canceled or UnexpectedEOF (resumable) - return err + // return nil or context.Canceled or UnexpectedEOF (resumable) + return err + }) + + g.Go(func() error { + ticker := time.NewTicker(time.Second) + for { + select { + case <-ticker.C: + if part.Completed >= part.Size { + return nil + } + + if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second { + slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N)) + // reset last 
updated + part.lastUpdated = time.Time{} + return errPartStalled + } + case <-ctx.Done(): + return ctx.Err() + } + } + }) + + return g.Wait() } func (b *blobDownload) newPart(offset, size int64) error { @@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error return json.NewEncoder(partFile).Encode(part) } -func (b *blobDownload) Write(p []byte) (n int, err error) { - n = len(p) - b.Completed.Add(int64(n)) - return n, nil -} - func (b *blobDownload) acquire() { b.references.Add(1) } @@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) for { select { case <-ticker.C: + fn(api.ProgressResponse{ + Status: fmt.Sprintf("pulling %s", b.Digest[7:19]), + Digest: b.Digest, + Total: b.Total, + Completed: b.Completed.Load(), + }) + + if b.done || b.err != nil { + return b.err + } case <-ctx.Done(): return ctx.Err() } - - fn(api.ProgressResponse{ - Status: fmt.Sprintf("pulling %s", b.Digest[7:19]), - Digest: b.Digest, - Total: b.Total, - Completed: b.Completed.Load(), - }) - - if b.done || b.err != nil { - return b.err - } } } @@ -303,10 +338,6 @@ type downloadOpts struct { fn func(api.ProgressResponse) } -const maxRetries = 6 - -var errMaxRetriesExceeded = errors.New("max retries exceeded") - // downloadBlob downloads a blob from the registry and stores it in the blobs directory func downloadBlob(ctx context.Context, opts downloadOpts) error { fp, err := GetBlobsPath(opts.digest) diff --git a/server/images.go b/server/images.go index a20f6bd7..ab3b4faa 100644 --- a/server/images.go +++ b/server/images.go @@ -41,7 +41,7 @@ type Model struct { Config ConfigV2 ShortName string ModelPath string - OriginalModel string + ParentModel string AdapterPaths []string ProjectorPaths []string Template string @@ -50,6 +50,12 @@ type Model struct { Digest string Size int64 Options map[string]interface{} + Messages []Message +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` } type PromptVars struct { @@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) { switch layer.MediaType { case "application/vnd.ollama.image.model": model.ModelPath = filename - model.OriginalModel = layer.From + model.ParentModel = layer.From case "application/vnd.ollama.image.embed": // Deprecated in versions > 0.1.2 // TODO: remove this warning in a future version @@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) { if err = json.NewDecoder(params).Decode(&model.Options); err != nil { return nil, err } + case "application/vnd.ollama.image.messages": + msgs, err := os.Open(filename) + if err != nil { + return nil, err + } + defer msgs.Close() + + if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil { + return nil, err + } case "application/vnd.ollama.image.license": bts, err := os.ReadFile(filename) if err != nil { @@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars } var layers Layers + messages := []string{} params := make(map[string][]string) fromParams := make(map[string]any) for _, c := range commands { - slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args)) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) switch c.Name { @@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars } layers.Replace(layer) + case "message": + messages = append(messages, c.Args) default: params[c.Name] = append(params[c.Name], c.Args) } } + if len(messages) > 0 { + 
fn(api.ProgressResponse{Status: "creating parameters layer"}) + + msgs := make([]api.Message, 0) + + for _, m := range messages { + // todo: handle images + msg := strings.SplitN(m, ": ", 2) + msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]}) + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(msgs); err != nil { + return err + } + + layer, err := NewLayer(&b, "application/vnd.ollama.image.messages") + if err != nil { + return err + } + + layers.Replace(layer) + } + if len(params) > 0 { fn(api.ProgressResponse{Status: "creating parameters layer"}) @@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) { mt.Model = model mt.From = model.ModelPath - if model.OriginalModel != "" { - mt.From = model.OriginalModel + if model.ParentModel != "" { + mt.From = model.ParentModel } modelFile := `# Modelfile generated by "ollama show" diff --git a/server/routes.go b/server/routes.go index 0c145ae6..56c275c9 100644 --- a/server/routes.go +++ b/server/routes.go @@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) { return } - sessionDuration := defaultSessionDuration + var sessionDuration time.Duration + if req.KeepAlive == nil { + sessionDuration = defaultSessionDuration + } else { + sessionDuration = req.KeepAlive.Duration + } + if err := load(c, model, opts, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - sessionDuration := defaultSessionDuration + + var sessionDuration time.Duration + if req.KeepAlive == nil { + sessionDuration = defaultSessionDuration + } else { + sessionDuration = req.KeepAlive.Duration + } + if err := load(c, model, opts, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -659,6 +672,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } modelDetails := api.ModelDetails{ + ParentModel: model.ParentModel, Format: model.Config.ModelFormat, Family: model.Config.ModelFamily, Families: model.Config.ModelFamilies, @@ -674,11 +688,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { model.Template = req.Template } + msgs := make([]api.Message, 0) + for _, msg := range model.Messages { + msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content}) + } + resp := &api.ShowResponse{ License: strings.Join(model.License, "\n"), System: model.System, Template: model.Template, Details: modelDetails, + Messages: msgs, } var params []string @@ -1067,7 +1087,14 @@ func ChatHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - sessionDuration := defaultSessionDuration + + var sessionDuration time.Duration + if req.KeepAlive == nil { + sessionDuration = defaultSessionDuration + } else { + sessionDuration = req.KeepAlive.Duration + } + if err := load(c, model, opts, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -1075,7 +1102,13 @@ func ChatHandler(c *gin.Context) { // an empty request loads the model if len(req.Messages) == 0 { - c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}}) + resp := api.ChatResponse{ + CreatedAt: time.Now().UTC(), + Model: req.Model, + Done: true, + Message: api.Message{Role: "assistant"}, + } + c.JSON(http.StatusOK, resp) return }
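Reviewer note (not part of the patch): the new `keep_alive` field is plumbed through `GenerateHandler`, `EmbeddingHandler`, and `ChatHandler` above to override the default session duration. A minimal client sketch of the field — assuming a locally running server and a model named `llama2` that is already pulled; the model name and duration are illustrative only:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama2",
		Messages: []api.Message{
			{Role: "user", Content: "Why is the sky blue?"},
		},
		// Keep the model loaded for 10 minutes after this request instead
		// of the server's default session duration.
		KeepAlive: &api.Duration{Duration: 10 * time.Minute},
	}

	// Responses stream; print each chunk as it arrives.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```

Over raw HTTP the field also accepts a duration string such as `"keep_alive": "10m"`, and with the updated `Duration` unmarshalling a negative value keeps the model loaded indefinitely.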