From f397e0e988272ffd14bdfb6c4070bb3ab5328df2 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 5 Feb 2024 12:59:52 -0800 Subject: [PATCH] Move hub auth out to new package --- {server => auth}/auth.go | 68 ++++++++++++++++++-------------- auth/request.go | 72 ++++++++++++++++++++++++++++++++++ server/download.go | 11 +++--- server/images.go | 84 ++++++---------------------------------- server/routes.go | 5 ++- server/upload.go | 17 ++++---- 6 files changed, 142 insertions(+), 115 deletions(-) rename {server => auth}/auth.go (87%) create mode 100644 auth/request.go diff --git a/server/auth.go b/auth/auth.go similarity index 87% rename from server/auth.go rename to auth/auth.go index 0d09668d..c0ce0a52 100644 --- a/server/auth.go +++ b/auth/auth.go @@ -1,4 +1,4 @@ -package server +package auth import ( "bytes" @@ -24,6 +24,10 @@ import ( "github.com/jmorganca/ollama/api" ) +const ( + KeyType = "id_ed25519" +) + type AuthRedirect struct { Realm string Service string @@ -71,39 +75,47 @@ func (r AuthRedirect) URL() (*url.URL, error) { return redirectURL, nil } -func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) { +func SignRequest(method, url string, data []byte, headers http.Header) error { + home, err := os.UserHomeDir() + if err != nil { + return err + } + + keyPath := filepath.Join(home, ".ollama", KeyType) + + rawKey, err := os.ReadFile(keyPath) + if err != nil { + slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) + return err + } + + s := SignatureData{ + Method: method, + Path: url, + Data: data, + } + + sig, err := s.Sign(rawKey) + if err != nil { + return err + } + + headers.Set("Authorization", sig) + return nil +} + +func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) { redirectURL, err := redirData.URL() if err != nil { return "", err } - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - - keyPath := filepath.Join(home, ".ollama", "id_ed25519") - - rawKey, err := os.ReadFile(keyPath) - if err != nil { - slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) - return "", err - } - - s := SignatureData{ - Method: http.MethodGet, - Path: redirectURL.String(), - Data: nil, - } - - sig, err := s.Sign(rawKey) - if err != nil { - return "", err - } - headers := make(http.Header) - headers.Set("Authorization", sig) - resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil) + err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers) + if err != nil { + return "", err + } + resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil) if err != nil { slog.Info(fmt.Sprintf("couldn't get token: %q", err)) return "", err diff --git a/auth/request.go b/auth/request.go new file mode 100644 index 00000000..ab863fe3 --- /dev/null +++ b/auth/request.go @@ -0,0 +1,72 @@ +package auth + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "runtime" + "strconv" + + "github.com/jmorganca/ollama/version" +) + +type RegistryOptions struct { + Insecure bool + Username string + Password string + Token string +} + +func MakeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { + if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure { + requestURL.Scheme = "http" + } + + req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body) + if err != nil { + return nil, err + } + + if headers != nil { + req.Header = headers + } + + if regOpts != nil { + if regOpts.Token != "" { + req.Header.Set("Authorization", "Bearer "+regOpts.Token) + } else if regOpts.Username != "" && regOpts.Password != "" { + req.SetBasicAuth(regOpts.Username, regOpts.Password) + } + } + + req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) + + if s := req.Header.Get("Content-Length"); s != "" { + contentLength, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + + req.ContentLength = contentLength + } + + proxyURL, err := http.ProxyFromEnvironment(req) + if err != nil { + return nil, err + } + + client := http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + return resp, nil +} diff --git a/server/download.go b/server/download.go index f089bd41..dbfba2dd 100644 --- a/server/download.go +++ b/server/download.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/auth" "github.com/jmorganca/ollama/format" ) @@ -85,7 +86,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) { return n, nil } -func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { +func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error { partFilePaths, err := filepath.Glob(b.Name + "-partial-*") if err != nil { return err @@ -137,11 +138,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R return nil } -func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) { +func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) { b.err = b.run(ctx, requestURL, opts) } -func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { +func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error { defer blobDownloadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) @@ -210,7 +211,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis return nil } -func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { +func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { headers := make(http.Header) @@ -334,7 +335,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) type downloadOpts struct { mp ModelPath digest string - regOpts *RegistryOptions + regOpts *auth.RegistryOptions fn func(api.ProgressResponse) } diff --git a/server/images.go b/server/images.go index 55b68456..8a70cdd5 100644 --- a/server/images.go +++ b/server/images.go @@ -16,25 +16,17 @@ import ( "os" "path/filepath" "runtime" - "strconv" "strings" "text/template" "golang.org/x/exp/slices" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/auth" "github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/parser" - "github.com/jmorganca/ollama/version" ) -type RegistryOptions struct { - Insecure bool - Username string - Password string - Token string -} - type Model struct { Name string `json:"name"` Config ConfigV2 @@ -320,7 +312,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars switch { case errors.Is(err, os.ErrNotExist): fn(api.ProgressResponse{Status: "pulling model"}) - if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { + if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil { return err } @@ -840,7 +832,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }} return buf.String(), nil } -func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) @@ -890,7 +882,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return nil } -func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) var manifest *ManifestV2 @@ -996,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return nil } -func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { +func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) { requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) headers := make(http.Header) @@ -1028,9 +1020,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) { var errUnauthorized = fmt.Errorf("unauthorized") -func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { +func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) { for i := 0; i < 2; i++ { - resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) + resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { if !errors.Is(err, context.Canceled) { slog.Info(fmt.Sprintf("request failed: %v", err)) @@ -1042,9 +1034,9 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR switch { case resp.StatusCode == http.StatusUnauthorized: // Handle authentication error with one retry - auth := resp.Header.Get("www-authenticate") - authRedir := ParseAuthRedirectString(auth) - token, err := getAuthToken(ctx, authRedir) + authenticate := resp.Header.Get("www-authenticate") + authRedir := ParseAuthRedirectString(authenticate) + token, err := auth.GetAuthToken(ctx, authRedir) if err != nil { return nil, err } @@ -1071,58 +1063,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR return nil, errUnauthorized } -func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { - if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure { - requestURL.Scheme = "http" - } - - req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body) - if err != nil { - return nil, err - } - - if headers != nil { - req.Header = headers - } - - if regOpts != nil { - if regOpts.Token != "" { - req.Header.Set("Authorization", "Bearer "+regOpts.Token) - } else if regOpts.Username != "" && regOpts.Password != "" { - req.SetBasicAuth(regOpts.Username, regOpts.Password) - } - } - - req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) - - if s := req.Header.Get("Content-Length"); s != "" { - contentLength, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return nil, err - } - - req.ContentLength = contentLength - } - - proxyURL, err := http.ProxyFromEnvironment(req) - if err != nil { - return nil, err - } - - client := http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - }, - } - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - - return resp, nil -} - func getValue(header, key string) string { startIdx := strings.Index(header, key+"=") if startIdx == -1 { @@ -1146,10 +1086,10 @@ func getValue(header, key string) string { return header[startIdx:endIdx] } -func ParseAuthRedirectString(authStr string) AuthRedirect { +func ParseAuthRedirectString(authStr string) auth.AuthRedirect { authStr = strings.TrimPrefix(authStr, "Bearer ") - return AuthRedirect{ + return auth.AuthRedirect{ Realm: getValue(authStr, "realm"), Service: getValue(authStr, "service"), Scope: getValue(authStr, "scope"), diff --git a/server/routes.go b/server/routes.go index bd943ee1..ddf22e78 100644 --- a/server/routes.go +++ b/server/routes.go @@ -25,6 +25,7 @@ import ( "golang.org/x/exp/slices" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/auth" "github.com/jmorganca/ollama/gpu" "github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/openai" @@ -479,7 +480,7 @@ func PullModelHandler(c *gin.Context) { ch <- r } - regOpts := &RegistryOptions{ + regOpts := &auth.RegistryOptions{ Insecure: req.Insecure, } @@ -528,7 +529,7 @@ func PushModelHandler(c *gin.Context) { ch <- r } - regOpts := &RegistryOptions{ + regOpts := &auth.RegistryOptions{ Insecure: req.Insecure, } diff --git a/server/upload.go b/server/upload.go index 3609b308..525b27b8 100644 --- a/server/upload.go +++ b/server/upload.go @@ -18,6 +18,7 @@ import ( "time" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/auth" "github.com/jmorganca/ollama/format" "golang.org/x/sync/errgroup" ) @@ -49,7 +50,7 @@ const ( maxUploadPartSize int64 = 1000 * format.MegaByte ) -func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { +func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error { p, err := GetBlobsPath(b.Digest) if err != nil { return err @@ -121,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error. -func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { +func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) { defer blobUploadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) @@ -212,7 +213,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { b.done = true } -func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error { +func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error { headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", fmt.Sprintf("%d", part.Size)) @@ -227,7 +228,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL * md5sum := md5.New() w := &progressWriter{blobUpload: b} - resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts) + resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts) if err != nil { w.Rollback() return err @@ -277,9 +278,9 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL * case resp.StatusCode == http.StatusUnauthorized: w.Rollback() - auth := resp.Header.Get("www-authenticate") - authRedir := ParseAuthRedirectString(auth) - token, err := getAuthToken(ctx, authRedir) + authenticate := resp.Header.Get("www-authenticate") + authRedir := ParseAuthRedirectString(authenticate) + token, err := auth.GetAuthToken(ctx, authRedir) if err != nil { return err } @@ -364,7 +365,7 @@ func (p *progressWriter) Rollback() { p.written = 0 } -func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error { +func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error { requestURL := mp.BaseURL() requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)