use int64 consistently

This commit is contained in:
Michael Yang 2023-09-28 10:00:34 -07:00
parent 5f4008c296
commit f40b3de758
7 changed files with 59 additions and 59 deletions

View file

@ -88,8 +88,8 @@ type PullRequest struct {
type ProgressResponse struct { type ProgressResponse struct {
Status string `json:"status"` Status string `json:"status"`
Digest string `json:"digest,omitempty"` Digest string `json:"digest,omitempty"`
Total int `json:"total,omitempty"` Total int64 `json:"total,omitempty"`
Completed int `json:"completed,omitempty"` Completed int64 `json:"completed,omitempty"`
} }
type PushRequest struct { type PushRequest struct {
@ -106,7 +106,7 @@ type ListResponse struct {
type ModelResponse struct { type ModelResponse struct {
Name string `json:"name"` Name string `json:"name"`
ModifiedAt time.Time `json:"modified_at"` ModifiedAt time.Time `json:"modified_at"`
Size int `json:"size"` Size int64 `json:"size"`
Digest string `json:"digest"` Digest string `json:"digest"`
} }

View file

@ -78,18 +78,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
currentDigest = resp.Digest currentDigest = resp.Digest
switch { switch {
case strings.Contains(resp.Status, "embeddings"): case strings.Contains(resp.Status, "embeddings"):
bar = progressbar.Default(int64(resp.Total), resp.Status) bar = progressbar.Default(resp.Total, resp.Status)
bar.Set(resp.Completed) bar.Set64(resp.Completed)
default: default:
// pulling // pulling
bar = progressbar.DefaultBytes( bar = progressbar.DefaultBytes(
int64(resp.Total), resp.Total,
resp.Status, resp.Status,
) )
bar.Set(resp.Completed) bar.Set64(resp.Completed)
} }
} else if resp.Digest == currentDigest && resp.Digest != "" { } else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed) bar.Set64(resp.Completed)
} else { } else {
currentDigest = "" currentDigest = ""
if spinner != nil { if spinner != nil {
@ -160,13 +160,13 @@ func PushHandler(cmd *cobra.Command, args []string) error {
if resp.Digest != currentDigest && resp.Digest != "" { if resp.Digest != currentDigest && resp.Digest != "" {
currentDigest = resp.Digest currentDigest = resp.Digest
bar = progressbar.DefaultBytes( bar = progressbar.DefaultBytes(
int64(resp.Total), resp.Total,
fmt.Sprintf("pushing %s...", resp.Digest[7:19]), fmt.Sprintf("pushing %s...", resp.Digest[7:19]),
) )
bar.Set(resp.Completed) bar.Set64(resp.Completed)
} else if resp.Digest == currentDigest && resp.Digest != "" { } else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed) bar.Set64(resp.Completed)
} else { } else {
currentDigest = "" currentDigest = ""
fmt.Println(resp.Status) fmt.Println(resp.Status)
@ -349,13 +349,13 @@ func pull(model string, insecure bool) error {
if resp.Digest != currentDigest && resp.Digest != "" { if resp.Digest != currentDigest && resp.Digest != "" {
currentDigest = resp.Digest currentDigest = resp.Digest
bar = progressbar.DefaultBytes( bar = progressbar.DefaultBytes(
int64(resp.Total), resp.Total,
fmt.Sprintf("pulling %s...", resp.Digest[7:19]), fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
) )
bar.Set(resp.Completed) bar.Set64(resp.Completed)
} else if resp.Digest == currentDigest && resp.Digest != "" { } else if resp.Digest == currentDigest && resp.Digest != "" {
bar.Set(resp.Completed) bar.Set64(resp.Completed)
} else { } else {
currentDigest = "" currentDigest = ""
fmt.Println(resp.Status) fmt.Println(resp.Status)

View file

@ -187,7 +187,7 @@ type llama struct {
var errNoGPU = errors.New("nvidia-smi command failed") var errNoGPU = errors.New("nvidia-smi command failed")
// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs // CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
func CheckVRAM() (int, error) { func CheckVRAM() (int64, error) {
cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits") cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits")
var stdout bytes.Buffer var stdout bytes.Buffer
cmd.Stdout = &stdout cmd.Stdout = &stdout
@ -196,11 +196,11 @@ func CheckVRAM() (int, error) {
return 0, errNoGPU return 0, errNoGPU
} }
var total int var total int64
scanner := bufio.NewScanner(&stdout) scanner := bufio.NewScanner(&stdout)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
vram, err := strconv.Atoi(line) vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to parse available VRAM: %v", err) return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
} }

View file

@ -46,8 +46,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
// we already have the file, so return // we already have the file, so return
opts.fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Digest: opts.digest, Digest: opts.digest,
Total: int(fi.Size()), Total: fi.Size(),
Completed: int(fi.Size()), Completed: fi.Size(),
}) })
return nil return nil
@ -93,8 +93,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
// successful download while monitoring // successful download while monitoring
opts.fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Digest: f.Digest, Digest: f.Digest,
Total: int(fi.Size()), Total: fi.Size(),
Completed: int(fi.Size()), Completed: fi.Size(),
}) })
return true, false, nil return true, false, nil
} }
@ -109,8 +109,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
opts.fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", f.Digest), Status: fmt.Sprintf("downloading %s", f.Digest),
Digest: f.Digest, Digest: f.Digest,
Total: int(f.Total), Total: f.Total,
Completed: int(f.Completed), Completed: f.Completed,
}) })
return false, false, nil return false, false, nil
}() }()
@ -129,7 +129,7 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
} }
var ( var (
chunkSize = 1024 * 1024 // 1 MiB in bytes chunkSize int64 = 1024 * 1024 // 1 MiB in bytes
errDownload = fmt.Errorf("download failed") errDownload = fmt.Errorf("download failed")
) )
@ -147,7 +147,7 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
default: default:
size = fi.Size() size = fi.Size()
// Ensure the size is divisible by the chunk size by removing excess bytes // Ensure the size is divisible by the chunk size by removing excess bytes
size -= size % int64(chunkSize) size -= size % chunkSize
err := os.Truncate(f.FilePath+"-partial", size) err := os.Truncate(f.FilePath+"-partial", size)
if err != nil { if err != nil {
@ -200,8 +200,8 @@ outerLoop:
opts.fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("downloading %s", f.Digest), Status: fmt.Sprintf("downloading %s", f.Digest),
Digest: f.Digest, Digest: f.Digest,
Total: int(f.Total), Total: f.Total,
Completed: int(f.Completed), Completed: f.Completed,
}) })
if f.Completed >= f.Total { if f.Completed >= f.Total {
@ -213,8 +213,8 @@ outerLoop:
opts.fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("error renaming file: %v", err), Status: fmt.Sprintf("error renaming file: %v", err),
Digest: f.Digest, Digest: f.Digest,
Total: int(f.Total), Total: f.Total,
Completed: int(f.Completed), Completed: f.Completed,
}) })
return err return err
} }
@ -223,7 +223,7 @@ outerLoop:
} }
} }
n, err := io.CopyN(out, resp.Body, int64(chunkSize)) n, err := io.CopyN(out, resp.Body, chunkSize)
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("%w: %w", errDownload, err) return fmt.Errorf("%w: %w", errDownload, err)
} }

View file

@ -103,7 +103,7 @@ type ManifestV2 struct {
type Layer struct { type Layer struct {
MediaType string `json:"mediaType"` MediaType string `json:"mediaType"`
Digest string `json:"digest"` Digest string `json:"digest"`
Size int `json:"size"` Size int64 `json:"size"`
From string `json:"from,omitempty"` From string `json:"from,omitempty"`
} }
@ -129,11 +129,11 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"` DiffIDs []string `json:"diff_ids"`
} }
func (m *ManifestV2) GetTotalSize() int { func (m *ManifestV2) GetTotalSize() (total int64) {
var total int
for _, layer := range m.Layers { for _, layer := range m.Layers {
total += layer.Size total += layer.Size
} }
total += m.Config.Size total += m.Config.Size
return total return total
} }
@ -649,8 +649,8 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error)
e.fn(api.ProgressResponse{ e.fn(api.ProgressResponse{
Status: fmt.Sprintf("creating embeddings for file %s", filePath), Status: fmt.Sprintf("creating embeddings for file %s", filePath),
Digest: fileDigest, Digest: fileDigest,
Total: len(data) - 1, Total: int64(len(data) - 1),
Completed: i, Completed: int64(i),
}) })
if len(existing[d]) > 0 { if len(existing[d]) > 0 {
// already have an embedding for this line // already have an embedding for this line
@ -675,7 +675,7 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error)
Layer: Layer{ Layer: Layer{
MediaType: "application/vnd.ollama.image.embed", MediaType: "application/vnd.ollama.image.embed",
Digest: digest, Digest: digest,
Size: r.Len(), Size: r.Size(),
}, },
Reader: r, Reader: r,
} }
@ -1356,14 +1356,14 @@ func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
} }
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer // GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
func GetSHA256Digest(r io.Reader) (string, int) { func GetSHA256Digest(r io.Reader) (string, int64) {
h := sha256.New() h := sha256.New()
n, err := io.Copy(h, r) n, err := io.Copy(h, r)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
} }
// Function to check if a blob already exists in the Docker registry // Function to check if a blob already exists in the Docker registry

View file

@ -15,8 +15,8 @@ import (
) )
const ( const (
redirectChunkSize = 1024 * 1024 * 1024 redirectChunkSize int64 = 1024 * 1024 * 1024
regularChunkSize = 95 * 1024 * 1024 regularChunkSize int64 = 95 * 1024 * 1024
) )
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) { func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
@ -48,7 +48,7 @@ func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *Regis
return nil, 0, err return nil, 0, err
} }
return locationURL, int64(chunkSize), nil return locationURL, chunkSize, nil
} }
func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
@ -73,10 +73,10 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
fn: fn, fn: fn,
} }
for offset := int64(0); offset < int64(layer.Size); { for offset := int64(0); offset < layer.Size; {
chunk := int64(layer.Size) - offset chunk := layer.Size - offset
if chunk > int64(chunkSize) { if chunk > chunkSize {
chunk = int64(chunkSize) chunk = chunkSize
} }
resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw) resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
@ -85,7 +85,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
Status: fmt.Sprintf("error uploading chunk: %v", err), Status: fmt.Sprintf("error uploading chunk: %v", err),
Digest: layer.Digest, Digest: layer.Digest,
Total: layer.Size, Total: layer.Size,
Completed: int(offset), Completed: offset,
}) })
return err return err
@ -127,7 +127,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
} }
func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) { func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
sectionReader := io.NewSectionReader(r, int64(offset), limit) sectionReader := io.NewSectionReader(r, offset, limit)
headers := make(http.Header) headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
@ -152,7 +152,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
return nil, err return nil, err
} }
pw.completed = int(offset) pw.completed = offset
if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil { if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
// retry // retry
log.Printf("retrying redirected upload: %v", err) log.Printf("retrying redirected upload: %v", err)
@ -170,7 +170,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
opts.Token = token opts.Token = token
pw.completed = int(offset) pw.completed = offset
sectionReader = io.NewSectionReader(r, offset, limit) sectionReader = io.NewSectionReader(r, offset, limit)
continue continue
case resp.StatusCode >= http.StatusBadRequest: case resp.StatusCode >= http.StatusBadRequest:
@ -187,19 +187,19 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
type ProgressWriter struct { type ProgressWriter struct {
status string status string
digest string digest string
bucket int bucket int64
completed int completed int64
total int total int64
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }
func (pw *ProgressWriter) Write(b []byte) (int, error) { func (pw *ProgressWriter) Write(b []byte) (int, error) {
n := len(b) n := len(b)
pw.bucket += n pw.bucket += int64(n)
pw.completed += n
// throttle status updates to not spam the client // throttle status updates to not spam the client
if pw.bucket >= 1024*1024 || pw.completed >= pw.total { if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
pw.completed += pw.bucket
pw.fn(api.ProgressResponse{ pw.fn(api.ProgressResponse{
Status: pw.status, Status: pw.status,
Digest: pw.digest, Digest: pw.digest,