diff --git a/api/types.go b/api/types.go index ff643026..e33002d2 100644 --- a/api/types.go +++ b/api/types.go @@ -88,8 +88,8 @@ type PullRequest struct { type ProgressResponse struct { Status string `json:"status"` Digest string `json:"digest,omitempty"` - Total int `json:"total,omitempty"` - Completed int `json:"completed,omitempty"` + Total int64 `json:"total,omitempty"` + Completed int64 `json:"completed,omitempty"` } type PushRequest struct { @@ -106,7 +106,7 @@ type ListResponse struct { type ModelResponse struct { Name string `json:"name"` ModifiedAt time.Time `json:"modified_at"` - Size int `json:"size"` + Size int64 `json:"size"` Digest string `json:"digest"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index ee5b753e..d5d1a06a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -78,18 +78,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error { currentDigest = resp.Digest switch { case strings.Contains(resp.Status, "embeddings"): - bar = progressbar.Default(int64(resp.Total), resp.Status) - bar.Set(resp.Completed) + bar = progressbar.Default(resp.Total, resp.Status) + bar.Set64(resp.Completed) default: // pulling bar = progressbar.DefaultBytes( - int64(resp.Total), + resp.Total, resp.Status, ) - bar.Set(resp.Completed) + bar.Set64(resp.Completed) } } else if resp.Digest == currentDigest && resp.Digest != "" { - bar.Set(resp.Completed) + bar.Set64(resp.Completed) } else { currentDigest = "" if spinner != nil { @@ -160,13 +160,13 @@ func PushHandler(cmd *cobra.Command, args []string) error { if resp.Digest != currentDigest && resp.Digest != "" { currentDigest = resp.Digest bar = progressbar.DefaultBytes( - int64(resp.Total), + resp.Total, fmt.Sprintf("pushing %s...", resp.Digest[7:19]), ) - bar.Set(resp.Completed) + bar.Set64(resp.Completed) } else if resp.Digest == currentDigest && resp.Digest != "" { - bar.Set(resp.Completed) + bar.Set64(resp.Completed) } else { currentDigest = "" fmt.Println(resp.Status) @@ -349,13 +349,13 @@ func pull(model string, insecure bool) error { if resp.Digest != currentDigest && resp.Digest != "" { currentDigest = resp.Digest bar = progressbar.DefaultBytes( - int64(resp.Total), + resp.Total, fmt.Sprintf("pulling %s...", resp.Digest[7:19]), ) - bar.Set(resp.Completed) + bar.Set64(resp.Completed) } else if resp.Digest == currentDigest && resp.Digest != "" { - bar.Set(resp.Completed) + bar.Set64(resp.Completed) } else { currentDigest = "" fmt.Println(resp.Status) diff --git a/llm/llama.go b/llm/llama.go index 6e748302..3d1f2bff 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -187,7 +187,7 @@ type llama struct { var errNoGPU = errors.New("nvidia-smi command failed") // CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs -func CheckVRAM() (int, error) { +func CheckVRAM() (int64, error) { cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits") var stdout bytes.Buffer cmd.Stdout = &stdout @@ -196,11 +196,11 @@ func CheckVRAM() (int, error) { return 0, errNoGPU } - var total int + var total int64 scanner := bufio.NewScanner(&stdout) for scanner.Scan() { line := scanner.Text() - vram, err := strconv.Atoi(line) + vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64) if err != nil { return 0, fmt.Errorf("failed to parse available VRAM: %v", err) } diff --git a/server/download.go b/server/download.go index 90b1d14d..cde9214f 100644 --- a/server/download.go +++ b/server/download.go @@ -46,8 +46,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error { // we already have the file, so return opts.fn(api.ProgressResponse{ Digest: opts.digest, - Total: int(fi.Size()), - Completed: int(fi.Size()), + Total: fi.Size(), + Completed: fi.Size(), }) return nil @@ -93,8 +93,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er // successful download while monitoring opts.fn(api.ProgressResponse{ Digest: f.Digest, - Total: int(fi.Size()), - Completed: int(fi.Size()), + Total: fi.Size(), + Completed: fi.Size(), }) return true, false, nil } @@ -109,8 +109,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("downloading %s", f.Digest), Digest: f.Digest, - Total: int(f.Total), - Completed: int(f.Completed), + Total: f.Total, + Completed: f.Completed, }) return false, false, nil }() @@ -129,8 +129,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er } var ( - chunkSize = 1024 * 1024 // 1 MiB in bytes - errDownload = fmt.Errorf("download failed") + chunkSize int64 = 1024 * 1024 // 1 MiB in bytes + errDownload = fmt.Errorf("download failed") ) // doDownload downloads a blob from the registry and stores it in the blobs directory @@ -147,7 +147,7 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error { default: size = fi.Size() // Ensure the size is divisible by the chunk size by removing excess bytes - size -= size % int64(chunkSize) + size -= size % chunkSize err := os.Truncate(f.FilePath+"-partial", size) if err != nil { @@ -200,8 +200,8 @@ outerLoop: opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("downloading %s", f.Digest), Digest: f.Digest, - Total: int(f.Total), - Completed: int(f.Completed), + Total: f.Total, + Completed: f.Completed, }) if f.Completed >= f.Total { @@ -213,8 +213,8 @@ outerLoop: opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("error renaming file: %v", err), Digest: f.Digest, - Total: int(f.Total), - Completed: int(f.Completed), + Total: f.Total, + Completed: f.Completed, }) return err } @@ -223,7 +223,7 @@ outerLoop: } } - n, err := io.CopyN(out, resp.Body, int64(chunkSize)) + n, err := io.CopyN(out, resp.Body, chunkSize) if err != nil && !errors.Is(err, io.EOF) { return fmt.Errorf("%w: %w", errDownload, err) } diff --git a/server/images.go b/server/images.go index 57e78631..a6b0a5ea 100644 --- a/server/images.go +++ b/server/images.go @@ -103,7 +103,7 @@ type ManifestV2 struct { type Layer struct { MediaType string `json:"mediaType"` Digest string `json:"digest"` - Size int `json:"size"` + Size int64 `json:"size"` From string `json:"from,omitempty"` } @@ -129,11 +129,11 @@ type RootFS struct { DiffIDs []string `json:"diff_ids"` } -func (m *ManifestV2) GetTotalSize() int { - var total int +func (m *ManifestV2) GetTotalSize() (total int64) { for _, layer := range m.Layers { total += layer.Size } + total += m.Config.Size return total } @@ -649,8 +649,8 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) e.fn(api.ProgressResponse{ Status: fmt.Sprintf("creating embeddings for file %s", filePath), Digest: fileDigest, - Total: len(data) - 1, - Completed: i, + Total: int64(len(data) - 1), + Completed: int64(i), }) if len(existing[d]) > 0 { // already have an embedding for this line @@ -675,7 +675,7 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) Layer: Layer{ MediaType: "application/vnd.ollama.image.embed", Digest: digest, - Size: r.Len(), + Size: r.Size(), }, Reader: r, } @@ -1356,14 +1356,14 @@ func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) { } // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer -func GetSHA256Digest(r io.Reader) (string, int) { +func GetSHA256Digest(r io.Reader) (string, int64) { h := sha256.New() n, err := io.Copy(h, r) if err != nil { log.Fatal(err) } - return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) + return fmt.Sprintf("sha256:%x", h.Sum(nil)), n } // Function to check if a blob already exists in the Docker registry diff --git a/server/modelpath_test.go b/server/modelpath_test.go index c52c689c..8b26d52c 100644 --- a/server/modelpath_test.go +++ b/server/modelpath_test.go @@ -4,9 +4,9 @@ import "testing" func TestParseModelPath(t *testing.T) { tests := []struct { - name string - arg string - want ModelPath + name string + arg string + want ModelPath }{ { "full path https", diff --git a/server/upload.go b/server/upload.go index 618195f7..f5e100e0 100644 --- a/server/upload.go +++ b/server/upload.go @@ -15,8 +15,8 @@ import ( ) const ( - redirectChunkSize = 1024 * 1024 * 1024 - regularChunkSize = 95 * 1024 * 1024 + redirectChunkSize int64 = 1024 * 1024 * 1024 + regularChunkSize int64 = 95 * 1024 * 1024 ) func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) { @@ -48,7 +48,7 @@ func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *Regis return nil, 0, err } - return locationURL, int64(chunkSize), nil + return locationURL, chunkSize, nil } func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { @@ -73,10 +73,10 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz fn: fn, } - for offset := int64(0); offset < int64(layer.Size); { - chunk := int64(layer.Size) - offset - if chunk > int64(chunkSize) { - chunk = int64(chunkSize) + for offset := int64(0); offset < layer.Size; { + chunk := layer.Size - offset + if chunk > chunkSize { + chunk = chunkSize } resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw) @@ -85,7 +85,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz Status: fmt.Sprintf("error uploading chunk: %v", err), Digest: layer.Digest, Total: layer.Size, - Completed: int(offset), + Completed: offset, }) return err @@ -127,7 +127,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz } func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) { - sectionReader := io.NewSectionReader(r, int64(offset), limit) + sectionReader := io.NewSectionReader(r, offset, limit) headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") @@ -152,7 +152,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r return nil, err } - pw.completed = int(offset) + pw.completed = offset if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil { // retry log.Printf("retrying redirected upload: %v", err) @@ -170,7 +170,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r opts.Token = token - pw.completed = int(offset) + pw.completed = offset sectionReader = io.NewSectionReader(r, offset, limit) continue case resp.StatusCode >= http.StatusBadRequest: @@ -187,19 +187,19 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r type ProgressWriter struct { status string digest string - bucket int - completed int - total int + bucket int64 + completed int64 + total int64 fn func(api.ProgressResponse) } func (pw *ProgressWriter) Write(b []byte) (int, error) { n := len(b) - pw.bucket += n - pw.completed += n + pw.bucket += int64(n) // throttle status updates to not spam the client - if pw.bucket >= 1024*1024 || pw.completed >= pw.total { + if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total { + pw.completed += pw.bucket pw.fn(api.ProgressResponse{ Status: pw.status, Digest: pw.digest,