Merge branch 'main' into patch-1

This commit is contained in:
Timothy Jaeryang Baek 2023-11-04 19:12:18 -05:00 committed by GitHub
commit 6febde7200
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 580 additions and 2301 deletions

View file

@ -29,8 +29,7 @@ curl https://ollama.ai/install.sh | sh
### Docker ### Docker
The official [Ollama Docker image `ollama/ollama`](https://hub.docker.com/r/ollama/ollama) The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.
is available on Docker Hub.
## Quickstart ## Quickstart
@ -235,6 +234,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html) - [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
- [LiteLLM](https://github.com/BerriAI/litellm) - [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp) - [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
### Plugins (Extensions) ### Plugins (Extensions)
- [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama) - [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama)
@ -245,5 +245,3 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot) - [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot)
- [Dumbar](https://github.com/JerrySievert/Dumbar) - [Dumbar](https://github.com/JerrySievert/Dumbar)

View file

@ -72,7 +72,7 @@ func ClientFromEnvironment() (*Client, error) {
}, },
} }
mockRequest, err := http.NewRequest("HEAD", client.base.String(), nil) mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -293,7 +293,7 @@ func DefaultOptions() Options {
return Options{ return Options{
// options set on request to runner // options set on request to runner
NumPredict: -1, NumPredict: -1,
NumKeep: -1, NumKeep: 0,
Temperature: 0.8, Temperature: 0.8,
TopK: 40, TopK: 40,
TopP: 0.9, TopP: 0.9,

View file

@ -11,6 +11,7 @@ import (
"io" "io"
"log" "log"
"net" "net"
"net/http"
"os" "os"
"os/exec" "os/exec"
"os/signal" "os/signal"
@ -98,19 +99,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
models, err := client.List(context.Background()) name := args[0]
if err != nil { // check if the model exists on the server
return err _, err = client.Show(context.Background(), &api.ShowRequest{Name: name})
} var statusError api.StatusError
switch {
canonicalModelPath := server.ParseModelPath(args[0]) case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
for _, model := range models.Models { if err := PullHandler(cmd, args); err != nil {
if model.Name == canonicalModelPath.GetShortTagname() { return err
return RunGenerate(cmd, args)
} }
} case err != nil:
if err := PullHandler(cmd, args); err != nil {
return err return err
} }
@ -731,21 +729,6 @@ func RunServer(cmd *cobra.Command, _ []string) error {
origins = strings.Split(o, ",") origins = strings.Split(o, ",")
} }
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
if err := server.PruneLayers(); err != nil {
return err
}
manifestsPath, err := server.GetManifestPath()
if err != nil {
return err
}
if err := server.PruneDirectory(manifestsPath); err != nil {
return err
}
}
return server.Serve(ln, origins) return server.Serve(ln, origins)
} }

View file

@ -45,9 +45,11 @@ Advanced parameters (optional):
- `system`: system prompt to (overrides what is defined in the `Modelfile`) - `system`: system prompt to (overrides what is defined in the `Modelfile`)
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory - `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
### Request ### Examples
#### Request
```shell ```shell
curl -X POST http://localhost:11434/api/generate -d '{ curl -X POST http://localhost:11434/api/generate -d '{
@ -56,9 +58,9 @@ curl -X POST http://localhost:11434/api/generate -d '{
}' }'
``` ```
### Response #### Response
A stream of JSON objects: A stream of JSON objects is returned:
```json ```json
{ {
@ -102,6 +104,38 @@ To calculate how fast the response is generated in tokens per second (token/s),
} }
``` ```
#### Request
```shell
curl -X POST http://localhost:11434/api/generate -d '{
"model": "llama2:7b",
"prompt": "Why is the sky blue?",
"stream": false
}'
```
#### Response
If `stream` is set to `false`, the response will be a single JSON object:
```json
{
"model": "llama2:7b",
"created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.",
"context": [1, 2, 3],
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 13,
"eval_duration": 1325948000
}
```
## Create a Model ## Create a Model
```shell ```shell
@ -114,9 +148,11 @@ Create a model from a [`Modelfile`](./modelfile.md)
- `name`: name of the model to create - `name`: name of the model to create
- `path`: path to the Modelfile - `path`: path to the Modelfile
- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Request ### Examples
#### Request
```shell ```shell
curl -X POST http://localhost:11434/api/create -d '{ curl -X POST http://localhost:11434/api/create -d '{
@ -125,7 +161,7 @@ curl -X POST http://localhost:11434/api/create -d '{
}' }'
``` ```
### Response #### Response
A stream of JSON objects. When finished, `status` is `success`. A stream of JSON objects. When finished, `status` is `success`.
@ -143,13 +179,17 @@ GET /api/tags
List models that are available locally. List models that are available locally.
### Request ### Examples
#### Request
```shell ```shell
curl http://localhost:11434/api/tags curl http://localhost:11434/api/tags
``` ```
### Response #### Response
A single JSON object will be returned.
```json ```json
{ {
@ -180,7 +220,9 @@ Show details about a model including modelfile, template, parameters, license, a
- `name`: name of the model to show - `name`: name of the model to show
### Request ### Examples
#### Request
```shell ```shell
curl http://localhost:11434/api/show -d '{ curl http://localhost:11434/api/show -d '{
@ -188,7 +230,7 @@ curl http://localhost:11434/api/show -d '{
}' }'
``` ```
### Response #### Response
```json ```json
{ {
@ -207,7 +249,9 @@ POST /api/copy
Copy a model. Creates a model with another name from an existing model. Copy a model. Creates a model with another name from an existing model.
### Request ### Examples
#### Request
```shell ```shell
curl http://localhost:11434/api/copy -d '{ curl http://localhost:11434/api/copy -d '{
@ -216,6 +260,10 @@ curl http://localhost:11434/api/copy -d '{
}' }'
``` ```
#### Response
The only response is a 200 OK if successful.
## Delete a Model ## Delete a Model
```shell ```shell
@ -226,9 +274,11 @@ Delete a model and its data.
### Parameters ### Parameters
- `model`: model name to delete - `name`: model name to delete
### Request ### Examples
#### Request
```shell ```shell
curl -X DELETE http://localhost:11434/api/delete -d '{ curl -X DELETE http://localhost:11434/api/delete -d '{
@ -236,6 +286,10 @@ curl -X DELETE http://localhost:11434/api/delete -d '{
}' }'
``` ```
#### Response
If successful, the only response is a 200 OK.
## Pull a Model ## Pull a Model
```shell ```shell
@ -248,9 +302,11 @@ Download a model from the ollama library. Cancelled pulls are resumed from where
- `name`: name of the model to pull - `name`: name of the model to pull
- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pulling from your own library during development. - `insecure`: (optional) allow insecure connections to the library. Only use this if you are pulling from your own library during development.
- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Request ### Examples
#### Request
```shell ```shell
curl -X POST http://localhost:11434/api/pull -d '{ curl -X POST http://localhost:11434/api/pull -d '{
@ -258,13 +314,51 @@ curl -X POST http://localhost:11434/api/pull -d '{
}' }'
``` ```
### Response #### Response
If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
The first object is the manifest:
```json
{
"status": "pulling manifest"
}
```
Then there is a series of downloading responses. Until any of the download is completed, the `completed` key may not be included. The number of files to be downloaded depends on the number of layers specified in the manifest.
```json ```json
{ {
"status": "downloading digestname", "status": "downloading digestname",
"digest": "digestname", "digest": "digestname",
"total": 2142590208 "total": 2142590208,
"completed": 241970
}
```
After all the files are downloaded, the final responses are:
```json
{
"status": "verifying sha256 digest"
}
{
"status": "writing manifest"
}
{
"status": "removing any unused layers"
}
{
"status": "success"
}
```
if `stream` is set to false, then the response is a single JSON object:
```json
{
"status": "success"
} }
``` ```
@ -280,9 +374,11 @@ Upload a model to a model library. Requires registering for ollama.ai and adding
- `name`: name of the model to push in the form of `<namespace>/<model>:<tag>` - `name`: name of the model to push in the form of `<namespace>/<model>:<tag>`
- `insecure`: (optional) allow insecure connections to the library. Only use this if you are pushing to your library during development. - `insecure`: (optional) allow insecure connections to the library. Only use this if you are pushing to your library during development.
- `stream`: (optional) if `false` the response will be be returned as a single response object, rather than a stream of objects - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
### Request ### Examples
#### Request
```shell ```shell
curl -X POST http://localhost:11434/api/push -d '{ curl -X POST http://localhost:11434/api/push -d '{
@ -290,9 +386,9 @@ curl -X POST http://localhost:11434/api/push -d '{
}' }'
``` ```
### Response #### Response
Streaming response that starts with: If `stream` is not specified, or set to `true`, a stream of JSON objects is returned:
```json ```json
{ "status": "retrieving manifest" } { "status": "retrieving manifest" }
@ -325,6 +421,12 @@ Finally, when the upload is complete:
{"status":"success"} {"status":"success"}
``` ```
If `stream` is set to `false`, then the response is a single JSON object:
```json
{ "status": "success" }
```
## Generate Embeddings ## Generate Embeddings
```shell ```shell
@ -342,7 +444,9 @@ Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
### Request ### Examples
#### Request
```shell ```shell
curl -X POST http://localhost:11434/api/embeddings -d '{ curl -X POST http://localhost:11434/api/embeddings -d '{
@ -351,7 +455,7 @@ curl -X POST http://localhost:11434/api/embeddings -d '{
}' }'
``` ```
### Response #### Response
```json ```json
{ {

View file

@ -185,7 +185,7 @@ python convert.py <path to model directory>
python convert-falcon-hf-to-gguf.py <path to model directory> python convert-falcon-hf-to-gguf.py <path to model directory>
# GPTNeoXForCausalLM # GPTNeoXForCausalLM
python convert-falcon-hf-to-gguf.py <path to model directory> python convert-gptneox-hf-to-gguf.py <path to model directory>
# GPTBigCodeForCausalLM # GPTBigCodeForCausalLM
python convert-starcoder-hf-to-gguf.py <path to model directory> python convert-starcoder-hf-to-gguf.py <path to model directory>

View file

@ -6,7 +6,6 @@ PERSIST_DIRECTORY = os.environ.get('PERSIST_DIRECTORY', 'db')
# Define the Chroma settings # Define the Chroma settings
CHROMA_SETTINGS = Settings( CHROMA_SETTINGS = Settings(
chroma_db_impl='duckdb+parquet',
persist_directory=PERSIST_DIRECTORY, persist_directory=PERSIST_DIRECTORY,
anonymized_telemetry=False anonymized_telemetry=False
) )

View file

@ -150,7 +150,7 @@ def main():
print("Creating new vectorstore") print("Creating new vectorstore")
texts = process_documents() texts = process_documents()
print(f"Creating embeddings. May take some minutes...") print(f"Creating embeddings. May take some minutes...")
db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS) db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)
db.persist() db.persist()
db = None db = None

View file

@ -4,6 +4,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma from langchain.vectorstores import Chroma
from langchain.llms import Ollama from langchain.llms import Ollama
import chromadb
import os import os
import argparse import argparse
import time import time
@ -22,7 +23,9 @@ def main():
# Parse the command line arguments # Parse the command line arguments
args = parse_arguments() args = parse_arguments()
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name) embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks}) retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
# activate/deactivate the streaming StdOut callback for LLMs # activate/deactivate the streaming StdOut callback for LLMs
callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()] callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]

File diff suppressed because it is too large Load diff

1
go.mod
View file

@ -11,7 +11,6 @@ require (
github.com/olekukonko/tablewriter v0.0.5 github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
golang.org/x/sync v0.3.0 golang.org/x/sync v0.3.0
gonum.org/v1/gonum v0.14.0
) )
require github.com/rivo/uniseg v0.2.0 // indirect require github.com/rivo/uniseg v0.2.0 // indirect

2
go.sum
View file

@ -140,8 +140,6 @@ golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=

View file

@ -306,13 +306,19 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
params := []string{ params := []string{
"--model", model, "--model", model,
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx), "--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
"--batch-size", fmt.Sprintf("%d", opts.NumBatch), "--batch-size", fmt.Sprintf("%d", opts.NumBatch),
"--n-gpu-layers", fmt.Sprintf("%d", numGPU), "--n-gpu-layers", fmt.Sprintf("%d", numGPU),
"--embedding", "--embedding",
} }
if opts.RopeFrequencyBase > 0 {
params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
}
if opts.RopeFrequencyScale > 0 {
params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
}
if opts.NumGQA > 0 { if opts.NumGQA > 0 {
params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA)) params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
} }
@ -360,7 +366,15 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
runner.Path, runner.Path,
append(params, "--port", strconv.Itoa(port))..., append(params, "--port", strconv.Itoa(port))...,
) )
cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", filepath.Dir(runner.Path)))
var libraryPaths []string
if libraryPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
libraryPaths = append(libraryPaths, libraryPath)
}
libraryPaths = append(libraryPaths, filepath.Dir(runner.Path))
cmd.Env = append(os.Environ(), fmt.Sprintf("LD_LIBRARY_PATH=%s", strings.Join(libraryPaths, ":")))
cmd.Stdout = os.Stderr cmd.Stdout = os.Stderr
statusWriter := NewStatusWriter() statusWriter := NewStatusWriter()
cmd.Stderr = statusWriter cmd.Stderr = statusWriter

View file

@ -85,7 +85,10 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
switch ggml.Name() { switch ggml.Name() {
case "gguf": case "gguf":
opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions // TODO: gguf will load these options automatically from the model binary
opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts) return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
case "ggml", "ggmf", "ggjt", "ggla": case "ggml", "ggmf", "ggjt", "ggla":
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts) return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)

View file

@ -2,6 +2,7 @@ package readline
import ( import (
"fmt" "fmt"
"os"
"github.com/emirpasic/gods/lists/arraylist" "github.com/emirpasic/gods/lists/arraylist"
"golang.org/x/term" "golang.org/x/term"
@ -17,7 +18,8 @@ type Buffer struct {
} }
func NewBuffer(prompt *Prompt) (*Buffer, error) { func NewBuffer(prompt *Prompt) (*Buffer, error) {
width, height, err := term.GetSize(0) fd := int(os.Stdout.Fd())
width, height, err := term.GetSize(fd)
if err != nil { if err != nil {
fmt.Println("Error getting size:", err) fmt.Println("Error getting size:", err)
return nil, err return nil, err

View file

@ -51,11 +51,12 @@ func (i *Instance) Readline() (string, error) {
} }
fmt.Print(prompt) fmt.Print(prompt)
termios, err := SetRawMode(syscall.Stdin) fd := int(syscall.Stdin)
termios, err := SetRawMode(fd)
if err != nil { if err != nil {
return "", err return "", err
} }
defer UnsetRawMode(syscall.Stdin, termios) defer UnsetRawMode(fd, termios)
buf, _ := NewBuffer(i.Prompt) buf, _ := NewBuffer(i.Prompt)

View file

@ -1,4 +1,5 @@
//go:build darwin || freebsd || netbsd || openbsd //go:build darwin || freebsd || netbsd || openbsd
package readline package readline
import ( import (

View file

@ -1,4 +1,5 @@
//go:build linux || solaris //go:build linux || solaris
package readline package readline
import ( import (

62
readline/term_windows.go Normal file
View file

@ -0,0 +1,62 @@
package readline
import (
"syscall"
"unsafe"
)
const (
enableLineInput = 2
enableWindowInput = 8
enableMouseInput = 16
enableInsertMode = 32
enableQuickEditMode = 64
enableExtendedFlags = 128
enableProcessedOutput = 1
enableWrapAtEolOutput = 2
enableAutoPosition = 256 // Cursor position is not affected by writing data to the console.
enableEchoInput = 4 // Characters are written to the console as they're read.
enableProcessedInput = 1 // Enables input processing (like recognizing Ctrl+C).
)
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
var (
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
)
type State struct {
mode uint32
}
// IsTerminal checks if the given file descriptor is associated with a terminal
func IsTerminal(fd int) bool {
var st uint32
r, _, e := syscall.SyscallN(procGetConsoleMode.Addr(), uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
// if the call succeeds and doesn't produce an error, it's a terminal
return r != 0 && e == 0
}
func SetRawMode(fd int) (*State, error) {
var st uint32
// retrieve the current mode of the terminal
_, _, e := syscall.SyscallN(procGetConsoleMode.Addr(), uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
if e != 0 {
return nil, error(e)
}
// modify the mode to set it to raw
raw := st &^ (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
// apply the new mode to the terminal
_, _, e = syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(raw), 0)
if e != 0 {
return nil, error(e)
}
// return the original state so that it can be restored later
return &State{st}, nil
}
func UnsetRawMode(fd int, state *State) error {
_, _, err := syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(state.mode), 0)
return err
}

View file

@ -63,7 +63,10 @@ status "Installing ollama to $BINDIR..."
$SUDO install -o0 -g0 -m755 -d $BINDIR $SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama $SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama
install_success() { status 'Install complete. Run "ollama" from the command line.'; } install_success() {
status 'The Ollama API is now available at 0.0.0.0:11434.'
status 'Install complete. Run "ollama" from the command line.'
}
trap install_success EXIT trap install_success EXIT
# Everything from this point onwards is optional. # Everything from this point onwards is optional.
@ -130,6 +133,7 @@ if check_gpu nvidia-smi; then
fi fi
if ! check_gpu lspci && ! check_gpu lshw; then if ! check_gpu lspci && ! check_gpu lshw; then
install_success
warning "No NVIDIA GPU detected. Ollama will run in CPU-only mode." warning "No NVIDIA GPU detected. Ollama will run in CPU-only mode."
exit 0 exit 0
fi fi

View file

@ -91,7 +91,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
} }
s := SignatureData{ s := SignatureData{
Method: "GET", Method: http.MethodGet,
Path: redirectURL.String(), Path: redirectURL.String(),
Data: nil, Data: nil,
} }
@ -103,7 +103,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Authorization", sig) headers.Set("Authorization", sig)
resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, nil) resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil { if err != nil {
log.Printf("couldn't get token: %q", err) log.Printf("couldn't get token: %q", err)
return "", err return "", err

View file

@ -89,17 +89,12 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
} }
if len(b.Parts) == 0 { if len(b.Parts) == 0 {
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, opts) resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
}
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
var size = b.Total / numDownloadParts var size = b.Total / numDownloadParts
@ -134,7 +129,6 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
defer blobDownloadManager.Delete(b.Digest) defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx) ctx, b.CancelFunc = context.WithCancel(ctx)
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644) file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
@ -170,7 +164,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
} }
} }
return errors.New("max retries exceeded") return errMaxRetriesExceeded
}) })
} }
@ -200,7 +194,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
headers := make(http.Header) headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, opts) resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil { if err != nil {
return err return err
} }
@ -308,6 +302,8 @@ type downloadOpts struct {
const maxRetries = 3 const maxRetries = 3
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// downloadBlob downloads a blob from the registry and stores it in the blobs directory // downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error { func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest) fp, err := GetBlobsPath(opts.digest)

View file

@ -63,15 +63,11 @@ func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
First bool First bool
System string System string
Prompt string Prompt string
// deprecated: versions <= 0.0.7 used this to omit the system prompt
Context []int
} }
vars.First = len(request.Context) == 0 vars.First = len(request.Context) == 0
vars.System = m.System vars.System = m.System
vars.Prompt = request.Prompt vars.Prompt = request.Prompt
vars.Context = request.Context
if request.System != "" { if request.System != "" {
vars.System = request.System vars.System = request.System
@ -981,46 +977,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
layers = append(layers, &manifest.Config) layers = append(layers, &manifest.Config)
for _, layer := range layers { for _, layer := range layers {
exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts) if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
if err != nil {
return err
}
if exists {
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
log.Printf("Layer %s already exists", layer.Digest)
continue
}
fn(api.ProgressResponse{
Status: "starting upload",
Digest: layer.Digest,
Total: layer.Size,
})
location, chunkSize, err := startUpload(ctx, mp, layer, regOpts)
if err != nil {
log.Printf("couldn't start upload: %v", err)
return err
}
if strings.HasPrefix(filepath.Base(location.Path), "sha256:") {
layer.Digest = filepath.Base(location.Path)
fn(api.ProgressResponse{
Status: "using existing layer",
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
continue
}
if err := uploadBlob(ctx, location, layer, chunkSize, regOpts, fn); err != nil {
log.Printf("error uploading blob: %v", err) log.Printf("error uploading blob: %v", err)
return err return err
} }
@ -1037,7 +994,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json") headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, "PUT", requestURL, headers, bytes.NewReader(manifestJSON), regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil { if err != nil {
return err return err
} }
@ -1159,22 +1116,12 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptio
headers := make(http.Header) headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json") headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequest(ctx, "GET", requestURL, headers, nil, regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
if err != nil { if err != nil {
log.Printf("couldn't get manifest: %v", err)
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("model not found")
}
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
}
var m *ManifestV2 var m *ManifestV2
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err return nil, err
@ -1218,24 +1165,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
} }
// Function to check if a blob already exists in the Docker registry
func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", digest)
resp, err := makeRequest(ctx, "HEAD", requestURL, nil, nil, regOpts)
if err != nil {
log.Printf("couldn't check for blob: %v", err)
return false, err
}
defer resp.Body.Close()
// Check for success: If the blob exists, the Docker registry will respond with a 200 OK
return resp.StatusCode < http.StatusBadRequest, nil
}
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
var status string
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil { if err != nil {
@ -1243,8 +1173,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
return nil, err return nil, err
} }
status = resp.Status
switch { switch {
case resp.StatusCode == http.StatusUnauthorized: case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate") auth := resp.Header.Get("www-authenticate")
@ -1256,21 +1184,25 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
regOpts.Token = token regOpts.Token = token
if body != nil { if body != nil {
if _, err := body.Seek(0, io.SeekStart); err != nil { body.Seek(0, io.SeekStart)
return nil, err
}
} }
continue continue
case resp.StatusCode == http.StatusNotFound:
return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest: case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body) if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, body)
default: default:
return resp, nil return resp, nil
} }
} }
return nil, fmt.Errorf("max retry exceeded: %v", status) return nil, errMaxRetriesExceeded
} }
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {

View file

@ -365,7 +365,9 @@ func PushModelHandler(c *gin.Context) {
Insecure: req.Insecure, Insecure: req.Insecure,
} }
ctx := context.Background() ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil { if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
@ -614,6 +616,22 @@ var defaultAllowOrigins = []string{
} }
func Serve(ln net.Listener, allowOrigins []string) error { func Serve(ln net.Listener, allowOrigins []string) error {
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
}
config := cors.DefaultConfig() config := cors.DefaultConfig()
config.AllowWildcard = true config.AllowWildcard = true
@ -679,7 +697,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
// check compatibility to log warnings // check compatibility to log warnings
if _, err := llm.CheckVRAM(); err != nil { if _, err := llm.CheckVRAM(); err != nil {
log.Printf("Warning: GPU support may not enabled, check you have installed install GPU drivers: %v", err) log.Printf("Warning: GPU support may not be enabled, check you have installed GPU drivers: %v", err)
} }
} }

View file

@ -2,218 +2,367 @@ package server
import ( import (
"context" "context"
"crypto/md5"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strconv" "strings"
"sync" "sync"
"sync/atomic"
"time"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
"golang.org/x/sync/errgroup"
) )
var blobUploadManager sync.Map
type blobUpload struct {
*Layer
Total int64
Completed atomic.Int64
Parts []blobUploadPart
nextURL chan *url.URL
context.CancelFunc
done bool
err error
references atomic.Int32
}
type blobUploadPart struct {
// N is the part number
N int
Offset int64
Size int64
hash.Hash
}
const ( const (
redirectChunkSize int64 = 1024 * 1024 * 1024 numUploadParts = 64
regularChunkSize int64 = 95 * 1024 * 1024 minUploadPartSize int64 = 95 * 1000 * 1000
maxUploadPartSize int64 = 1000 * 1000 * 1000
) )
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) { func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
requestURL := mp.BaseURL() p, err := GetBlobsPath(b.Digest)
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/") if err != nil {
if layer.From != "" { return err
}
if b.From != "" {
values := requestURL.Query() values := requestURL.Query()
values.Add("mount", layer.Digest) values.Add("mount", b.Digest)
values.Add("from", layer.From) values.Add("from", b.From)
requestURL.RawQuery = values.Encode() requestURL.RawQuery = values.Encode()
} }
resp, err := makeRequestWithRetry(ctx, "POST", requestURL, nil, nil, regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodPost, requestURL, nil, nil, opts)
if err != nil { if err != nil {
log.Printf("couldn't start upload: %v", err) return err
return nil, 0, err
} }
defer resp.Body.Close() defer resp.Body.Close()
location := resp.Header.Get("Docker-Upload-Location") location := resp.Header.Get("Docker-Upload-Location")
chunkSize := redirectChunkSize
if location == "" { if location == "" {
location = resp.Header.Get("Location") location = resp.Header.Get("Location")
chunkSize = regularChunkSize
} }
locationURL, err := url.Parse(location) fi, err := os.Stat(p)
if err != nil { if err != nil {
return nil, 0, err return err
} }
return locationURL, chunkSize, nil b.Total = fi.Size()
var size = b.Total / numUploadParts
switch {
case size < minUploadPartSize:
size = minUploadPartSize
case size > maxUploadPartSize:
size = maxUploadPartSize
}
var offset int64
for offset < fi.Size() {
if offset+size > fi.Size() {
size = fi.Size() - offset
}
// set part.N to the current number of parts
b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: size, Hash: md5.New()})
offset += size
}
log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
requestURL, err = url.Parse(location)
if err != nil {
return err
}
b.nextURL = make(chan *url.URL, 1)
b.nextURL <- requestURL
return nil
} }
func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// TODO allow resumability // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
// TODO allow canceling uploads via DELETE func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
fp, err := GetBlobsPath(layer.Digest) p, err := GetBlobsPath(b.Digest)
if err != nil { if err != nil {
return err b.err = err
return
} }
f, err := os.Open(fp) f, err := os.Open(p)
if err != nil { if err != nil {
return err b.err = err
return
} }
defer f.Close() defer f.Close()
pw := ProgressWriter{ g, inner := errgroup.WithContext(ctx)
status: fmt.Sprintf("uploading %s", layer.Digest), g.SetLimit(numUploadParts)
digest: layer.Digest, for i := range b.Parts {
total: layer.Size, part := &b.Parts[i]
fn: fn, select {
} case <-inner.Done():
case requestURL := <-b.nextURL:
g.Go(func() error {
for try := 0; try < maxRetries; try++ {
r := io.NewSectionReader(f, part.Offset, part.Size)
err := b.uploadChunk(inner, http.MethodPatch, requestURL, r, part, opts)
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
continue
}
for offset := int64(0); offset < layer.Size; { return nil
chunk := layer.Size - offset }
if chunk > chunkSize {
chunk = chunkSize
}
resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw) return errMaxRetriesExceeded
if err != nil {
fn(api.ProgressResponse{
Status: fmt.Sprintf("error uploading chunk: %v", err),
Digest: layer.Digest,
Total: layer.Size,
Completed: offset,
}) })
return err
}
offset += chunk
location := resp.Header.Get("Docker-Upload-Location")
if location == "" {
location = resp.Header.Get("Location")
}
requestURL, err = url.Parse(location)
if err != nil {
return err
} }
} }
if err := g.Wait(); err != nil {
b.err = err
return
}
requestURL := <-b.nextURL
var sb strings.Builder
for _, part := range b.Parts {
sb.Write(part.Sum(nil))
}
md5sum := md5.Sum([]byte(sb.String()))
values := requestURL.Query() values := requestURL.Query()
values.Add("digest", layer.Digest) values.Add("digest", b.Digest)
values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts)))
requestURL.RawQuery = values.Encode() requestURL.RawQuery = values.Encode()
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0") headers.Set("Content-Length", "0")
// finish the upload resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
resp, err := makeRequest(ctx, "PUT", requestURL, headers, nil, regOpts) if err != nil {
b.err = err
return
}
defer resp.Body.Close()
b.done = true
}
func (b *blobUpload) uploadChunk(ctx context.Context, method string, requestURL *url.URL, rs io.ReadSeeker, part *blobUploadPart, opts *RegistryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
headers.Set("X-Redirect-Uploads", "1")
if method == http.MethodPatch {
headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
}
buw := blobUploadWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(rs, io.MultiWriter(&buw, part.Hash)), opts)
if err != nil { if err != nil {
log.Printf("couldn't finish upload: %v", err)
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest { location := resp.Header.Get("Docker-Upload-Location")
body, _ := io.ReadAll(resp.Body) if location == "" {
return fmt.Errorf("on finish upload registry responded with code %d: %v", resp.StatusCode, string(body)) location = resp.Header.Get("Location")
}
return nil
}
func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
sectionReader := io.NewSectionReader(r, offset, limit)
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.Itoa(int(limit)))
headers.Set("X-Redirect-Uploads", "1")
if method == http.MethodPatch {
headers.Set("Content-Range", fmt.Sprintf("%d-%d", offset, offset+sectionReader.Size()-1))
} }
for try := 0; try < maxRetries; try++ { nextURL, err := url.Parse(location)
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sectionReader, pw), opts) if err != nil {
if err != nil && !errors.Is(err, io.EOF) { return err
return nil, err }
switch {
case resp.StatusCode == http.StatusTemporaryRedirect:
b.nextURL <- nextURL
redirectURL, err := resp.Location()
if err != nil {
return err
} }
defer resp.Body.Close()
switch { for try := 0; try < maxRetries; try++ {
case resp.StatusCode == http.StatusTemporaryRedirect: rs.Seek(0, io.SeekStart)
location, err := resp.Location() b.Completed.Add(-buw.written)
if err != nil { buw.written = 0
return nil, err part.Hash = md5.New()
} err := b.uploadChunk(ctx, http.MethodPut, redirectURL, rs, part, nil)
switch {
pw.completed = offset case errors.Is(err, context.Canceled):
if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil { return err
// retry case errors.Is(err, errMaxRetriesExceeded):
log.Printf("retrying redirected upload: %v", err) return err
case err != nil:
log.Printf("%s part %d attempt %d failed: %v, retrying", b.Digest[7:19], part.N, try, err)
continue continue
} }
return resp, nil return nil
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil {
return nil, err
}
opts.Token = token
pw.completed = offset
sectionReader = io.NewSectionReader(r, offset, limit)
continue
case resp.StatusCode >= http.StatusBadRequest:
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body)
} }
return resp, nil return errMaxRetriesExceeded
case resp.StatusCode == http.StatusUnauthorized:
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil {
return err
}
opts.Token = token
fallthrough
case resp.StatusCode >= http.StatusBadRequest:
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
rs.Seek(0, io.SeekStart)
b.Completed.Add(-buw.written)
buw.written = 0
return fmt.Errorf("http status %d %s: %s", resp.StatusCode, resp.Status, body)
} }
return nil, fmt.Errorf("max retries exceeded") if method == http.MethodPatch {
b.nextURL <- nextURL
}
return nil
} }
type ProgressWriter struct { func (b *blobUpload) acquire() {
status string b.references.Add(1)
digest string
bucket int64
completed int64
total int64
fn func(api.ProgressResponse)
mu sync.Mutex
} }
func (pw *ProgressWriter) Write(b []byte) (int, error) { func (b *blobUpload) release() {
pw.mu.Lock() if b.references.Add(-1) == 0 {
defer pw.mu.Unlock() b.CancelFunc()
}
}
n := len(b) func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
pw.bucket += int64(n) b.acquire()
defer b.release()
// throttle status updates to not spam the client ticker := time.NewTicker(60 * time.Millisecond)
if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total { for {
pw.completed += pw.bucket select {
pw.fn(api.ProgressResponse{ case <-ticker.C:
Status: pw.status, case <-ctx.Done():
Digest: pw.digest, return ctx.Err()
Total: pw.total, }
Completed: pw.completed,
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", b.Digest),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
}) })
pw.bucket = 0 if b.done || b.err != nil {
return b.err
}
} }
}
type blobUploadWriter struct {
written int64
*blobUpload
}
func (b *blobUploadWriter) Write(p []byte) (n int, err error) {
n = len(p)
b.written += int64(n)
b.Completed.Add(int64(n))
return n, nil return n, nil
} }
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return err
default:
defer resp.Body.Close()
fn(api.ProgressResponse{
Status: fmt.Sprintf("uploading %s", layer.Digest),
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
return nil
}
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err
}
go upload.Run(context.Background(), opts)
}
return upload.Wait(ctx, fn)
}