From f0f49435771352c4d1e432351675d10d9e23c099 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 7 Sep 2023 11:49:36 -0700 Subject: [PATCH 1/2] fix get auth token --- server/auth.go | 2 +- server/images.go | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/server/auth.go b/server/auth.go index 3e35178f..4238b252 100644 --- a/server/auth.go +++ b/server/auth.go @@ -103,7 +103,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect, regOpts *Registry headers := make(http.Header) headers.Set("Authorization", sig) - resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, regOpts) + resp, err := makeRequest(ctx, "GET", redirectURL, headers, nil, nil) if err != nil { log.Printf("couldn't get token: %q", err) } diff --git a/server/images.go b/server/images.go index 1356c9e9..91819910 100644 --- a/server/images.go +++ b/server/images.go @@ -1313,10 +1313,12 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header req.Header = headers } - if regOpts.Token != "" { - req.Header.Set("Authorization", "Bearer "+regOpts.Token) - } else if regOpts.Username != "" && regOpts.Password != "" { - req.SetBasicAuth(regOpts.Username, regOpts.Password) + if regOpts != nil { + if regOpts.Token != "" { + req.Header.Set("Authorization", "Bearer "+regOpts.Token) + } else if regOpts.Username != "" && regOpts.Password != "" { + req.SetBasicAuth(regOpts.Username, regOpts.Password) + } } req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) From bf146fb072b8dbf49efa2f874959ad978c48bf29 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 7 Sep 2023 12:01:50 -0700 Subject: [PATCH 2/2] fix retry on unauthorized chunk --- server/upload.go | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/server/upload.go b/server/upload.go index cf51eea6..a8c62827 100644 --- a/server/upload.go +++ b/server/upload.go @@ -66,31 +66,39 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r sectionReader := io.NewSectionReader(f, int64(offset), chunk) for try := 0; try < MaxRetries; try++ { + ch := make(chan error, 1) + r, w := io.Pipe() defer r.Close() go func() { defer w.Close() for chunked := int64(0); chunked < chunk; { - n, err := io.CopyN(w, sectionReader, 1024*1024) - if err != nil && !errors.Is(err, io.EOF) { + select { + case err := <-ch: + log.Printf("chunk interrupted: %v", err) + return + default: + n, err := io.CopyN(w, sectionReader, 1024*1024) + if err != nil && !errors.Is(err, io.EOF) { + fn(api.ProgressResponse{ + Status: fmt.Sprintf("error reading chunk: %v", err), + Digest: layer.Digest, + Total: layer.Size, + Completed: int(offset), + }) + + return + } + + chunked += n fn(api.ProgressResponse{ - Status: fmt.Sprintf("error reading chunk: %v", err), + Status: fmt.Sprintf("uploading %s", layer.Digest), Digest: layer.Digest, Total: layer.Size, - Completed: int(offset), + Completed: int(offset) + int(chunked), }) - - return } - - chunked += n - fn(api.ProgressResponse{ - Status: fmt.Sprintf("uploading %s", layer.Digest), - Digest: layer.Digest, - Total: layer.Size, - Completed: int(offset) + int(chunked), - }) } }() @@ -113,6 +121,8 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r switch { case resp.StatusCode == http.StatusUnauthorized: + ch <- errors.New("unauthorized") + auth := resp.Header.Get("www-authenticate") authRedir := ParseAuthRedirectString(auth) token, err := getAuthToken(ctx, authRedir, regOpts) @@ -121,10 +131,7 @@ func uploadBlobChunked(ctx context.Context, requestURL *url.URL, layer *Layer, r } regOpts.Token = token - if _, err := sectionReader.Seek(0, io.SeekStart); err != nil { - return err - } - + sectionReader = io.NewSectionReader(f, int64(offset), chunk) continue case resp.StatusCode >= http.StatusBadRequest: body, _ := io.ReadAll(resp.Body)