Merge pull request #634 from jmorganca/mxyng/int64
use int64 consistently
This commit is contained in:
commit
c951da7096
7 changed files with 59 additions and 59 deletions
|
@ -88,8 +88,8 @@ type PullRequest struct {
|
||||||
type ProgressResponse struct {
|
type ProgressResponse struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Digest string `json:"digest,omitempty"`
|
Digest string `json:"digest,omitempty"`
|
||||||
Total int `json:"total,omitempty"`
|
Total int64 `json:"total,omitempty"`
|
||||||
Completed int `json:"completed,omitempty"`
|
Completed int64 `json:"completed,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PushRequest struct {
|
type PushRequest struct {
|
||||||
|
@ -106,7 +106,7 @@ type ListResponse struct {
|
||||||
type ModelResponse struct {
|
type ModelResponse struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ModifiedAt time.Time `json:"modified_at"`
|
ModifiedAt time.Time `json:"modified_at"`
|
||||||
Size int `json:"size"`
|
Size int64 `json:"size"`
|
||||||
Digest string `json:"digest"`
|
Digest string `json:"digest"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
22
cmd/cmd.go
22
cmd/cmd.go
|
@ -78,18 +78,18 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
currentDigest = resp.Digest
|
currentDigest = resp.Digest
|
||||||
switch {
|
switch {
|
||||||
case strings.Contains(resp.Status, "embeddings"):
|
case strings.Contains(resp.Status, "embeddings"):
|
||||||
bar = progressbar.Default(int64(resp.Total), resp.Status)
|
bar = progressbar.Default(resp.Total, resp.Status)
|
||||||
bar.Set(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
default:
|
default:
|
||||||
// pulling
|
// pulling
|
||||||
bar = progressbar.DefaultBytes(
|
bar = progressbar.DefaultBytes(
|
||||||
int64(resp.Total),
|
resp.Total,
|
||||||
resp.Status,
|
resp.Status,
|
||||||
)
|
)
|
||||||
bar.Set(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
}
|
}
|
||||||
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
||||||
bar.Set(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
} else {
|
} else {
|
||||||
currentDigest = ""
|
currentDigest = ""
|
||||||
if spinner != nil {
|
if spinner != nil {
|
||||||
|
@ -160,13 +160,13 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
if resp.Digest != currentDigest && resp.Digest != "" {
|
if resp.Digest != currentDigest && resp.Digest != "" {
|
||||||
currentDigest = resp.Digest
|
currentDigest = resp.Digest
|
||||||
bar = progressbar.DefaultBytes(
|
bar = progressbar.DefaultBytes(
|
||||||
int64(resp.Total),
|
resp.Total,
|
||||||
fmt.Sprintf("pushing %s...", resp.Digest[7:19]),
|
fmt.Sprintf("pushing %s...", resp.Digest[7:19]),
|
||||||
)
|
)
|
||||||
|
|
||||||
bar.Set(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
||||||
bar.Set(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
} else {
|
} else {
|
||||||
currentDigest = ""
|
currentDigest = ""
|
||||||
fmt.Println(resp.Status)
|
fmt.Println(resp.Status)
|
||||||
|
@ -349,13 +349,13 @@ func pull(model string, insecure bool) error {
|
||||||
if resp.Digest != currentDigest && resp.Digest != "" {
|
if resp.Digest != currentDigest && resp.Digest != "" {
|
||||||
currentDigest = resp.Digest
|
currentDigest = resp.Digest
|
||||||
bar = progressbar.DefaultBytes(
|
bar = progressbar.DefaultBytes(
|
||||||
int64(resp.Total),
|
resp.Total,
|
||||||
fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
|
fmt.Sprintf("pulling %s...", resp.Digest[7:19]),
|
||||||
)
|
)
|
||||||
|
|
||||||
bar.Set(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
} else if resp.Digest == currentDigest && resp.Digest != "" {
|
||||||
bar.Set(resp.Completed)
|
bar.Set64(resp.Completed)
|
||||||
} else {
|
} else {
|
||||||
currentDigest = ""
|
currentDigest = ""
|
||||||
fmt.Println(resp.Status)
|
fmt.Println(resp.Status)
|
||||||
|
|
|
@ -187,7 +187,7 @@ type llama struct {
|
||||||
var errNoGPU = errors.New("nvidia-smi command failed")
|
var errNoGPU = errors.New("nvidia-smi command failed")
|
||||||
|
|
||||||
// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
|
// CheckVRAM returns the available VRAM in MiB on Linux machines with NVIDIA GPUs
|
||||||
func CheckVRAM() (int, error) {
|
func CheckVRAM() (int64, error) {
|
||||||
cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits")
|
cmd := exec.Command("nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits")
|
||||||
var stdout bytes.Buffer
|
var stdout bytes.Buffer
|
||||||
cmd.Stdout = &stdout
|
cmd.Stdout = &stdout
|
||||||
|
@ -196,11 +196,11 @@ func CheckVRAM() (int, error) {
|
||||||
return 0, errNoGPU
|
return 0, errNoGPU
|
||||||
}
|
}
|
||||||
|
|
||||||
var total int
|
var total int64
|
||||||
scanner := bufio.NewScanner(&stdout)
|
scanner := bufio.NewScanner(&stdout)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
vram, err := strconv.Atoi(line)
|
vram, err := strconv.ParseInt(strings.TrimSpace(line), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
|
return 0, fmt.Errorf("failed to parse available VRAM: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,8 +46,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||||
// we already have the file, so return
|
// we already have the file, so return
|
||||||
opts.fn(api.ProgressResponse{
|
opts.fn(api.ProgressResponse{
|
||||||
Digest: opts.digest,
|
Digest: opts.digest,
|
||||||
Total: int(fi.Size()),
|
Total: fi.Size(),
|
||||||
Completed: int(fi.Size()),
|
Completed: fi.Size(),
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -93,8 +93,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
|
||||||
// successful download while monitoring
|
// successful download while monitoring
|
||||||
opts.fn(api.ProgressResponse{
|
opts.fn(api.ProgressResponse{
|
||||||
Digest: f.Digest,
|
Digest: f.Digest,
|
||||||
Total: int(fi.Size()),
|
Total: fi.Size(),
|
||||||
Completed: int(fi.Size()),
|
Completed: fi.Size(),
|
||||||
})
|
})
|
||||||
return true, false, nil
|
return true, false, nil
|
||||||
}
|
}
|
||||||
|
@ -109,8 +109,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
|
||||||
opts.fn(api.ProgressResponse{
|
opts.fn(api.ProgressResponse{
|
||||||
Status: fmt.Sprintf("downloading %s", f.Digest),
|
Status: fmt.Sprintf("downloading %s", f.Digest),
|
||||||
Digest: f.Digest,
|
Digest: f.Digest,
|
||||||
Total: int(f.Total),
|
Total: f.Total,
|
||||||
Completed: int(f.Completed),
|
Completed: f.Completed,
|
||||||
})
|
})
|
||||||
return false, false, nil
|
return false, false, nil
|
||||||
}()
|
}()
|
||||||
|
@ -129,8 +129,8 @@ func monitorDownload(ctx context.Context, opts downloadOpts, f *FileDownload) er
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
chunkSize = 1024 * 1024 // 1 MiB in bytes
|
chunkSize int64 = 1024 * 1024 // 1 MiB in bytes
|
||||||
errDownload = fmt.Errorf("download failed")
|
errDownload = fmt.Errorf("download failed")
|
||||||
)
|
)
|
||||||
|
|
||||||
// doDownload downloads a blob from the registry and stores it in the blobs directory
|
// doDownload downloads a blob from the registry and stores it in the blobs directory
|
||||||
|
@ -147,7 +147,7 @@ func doDownload(ctx context.Context, opts downloadOpts, f *FileDownload) error {
|
||||||
default:
|
default:
|
||||||
size = fi.Size()
|
size = fi.Size()
|
||||||
// Ensure the size is divisible by the chunk size by removing excess bytes
|
// Ensure the size is divisible by the chunk size by removing excess bytes
|
||||||
size -= size % int64(chunkSize)
|
size -= size % chunkSize
|
||||||
|
|
||||||
err := os.Truncate(f.FilePath+"-partial", size)
|
err := os.Truncate(f.FilePath+"-partial", size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -200,8 +200,8 @@ outerLoop:
|
||||||
opts.fn(api.ProgressResponse{
|
opts.fn(api.ProgressResponse{
|
||||||
Status: fmt.Sprintf("downloading %s", f.Digest),
|
Status: fmt.Sprintf("downloading %s", f.Digest),
|
||||||
Digest: f.Digest,
|
Digest: f.Digest,
|
||||||
Total: int(f.Total),
|
Total: f.Total,
|
||||||
Completed: int(f.Completed),
|
Completed: f.Completed,
|
||||||
})
|
})
|
||||||
|
|
||||||
if f.Completed >= f.Total {
|
if f.Completed >= f.Total {
|
||||||
|
@ -213,8 +213,8 @@ outerLoop:
|
||||||
opts.fn(api.ProgressResponse{
|
opts.fn(api.ProgressResponse{
|
||||||
Status: fmt.Sprintf("error renaming file: %v", err),
|
Status: fmt.Sprintf("error renaming file: %v", err),
|
||||||
Digest: f.Digest,
|
Digest: f.Digest,
|
||||||
Total: int(f.Total),
|
Total: f.Total,
|
||||||
Completed: int(f.Completed),
|
Completed: f.Completed,
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -223,7 +223,7 @@ outerLoop:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := io.CopyN(out, resp.Body, int64(chunkSize))
|
n, err := io.CopyN(out, resp.Body, chunkSize)
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
return fmt.Errorf("%w: %w", errDownload, err)
|
return fmt.Errorf("%w: %w", errDownload, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,7 +103,7 @@ type ManifestV2 struct {
|
||||||
type Layer struct {
|
type Layer struct {
|
||||||
MediaType string `json:"mediaType"`
|
MediaType string `json:"mediaType"`
|
||||||
Digest string `json:"digest"`
|
Digest string `json:"digest"`
|
||||||
Size int `json:"size"`
|
Size int64 `json:"size"`
|
||||||
From string `json:"from,omitempty"`
|
From string `json:"from,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,11 +129,11 @@ type RootFS struct {
|
||||||
DiffIDs []string `json:"diff_ids"`
|
DiffIDs []string `json:"diff_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ManifestV2) GetTotalSize() int {
|
func (m *ManifestV2) GetTotalSize() (total int64) {
|
||||||
var total int
|
|
||||||
for _, layer := range m.Layers {
|
for _, layer := range m.Layers {
|
||||||
total += layer.Size
|
total += layer.Size
|
||||||
}
|
}
|
||||||
|
|
||||||
total += m.Config.Size
|
total += m.Config.Size
|
||||||
return total
|
return total
|
||||||
}
|
}
|
||||||
|
@ -649,8 +649,8 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error)
|
||||||
e.fn(api.ProgressResponse{
|
e.fn(api.ProgressResponse{
|
||||||
Status: fmt.Sprintf("creating embeddings for file %s", filePath),
|
Status: fmt.Sprintf("creating embeddings for file %s", filePath),
|
||||||
Digest: fileDigest,
|
Digest: fileDigest,
|
||||||
Total: len(data) - 1,
|
Total: int64(len(data) - 1),
|
||||||
Completed: i,
|
Completed: int64(i),
|
||||||
})
|
})
|
||||||
if len(existing[d]) > 0 {
|
if len(existing[d]) > 0 {
|
||||||
// already have an embedding for this line
|
// already have an embedding for this line
|
||||||
|
@ -675,7 +675,7 @@ func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error)
|
||||||
Layer: Layer{
|
Layer: Layer{
|
||||||
MediaType: "application/vnd.ollama.image.embed",
|
MediaType: "application/vnd.ollama.image.embed",
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Size: r.Len(),
|
Size: r.Size(),
|
||||||
},
|
},
|
||||||
Reader: r,
|
Reader: r,
|
||||||
}
|
}
|
||||||
|
@ -1356,14 +1356,14 @@ func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
|
// GetSHA256Digest returns the SHA256 hash of a given buffer and returns it, and the size of buffer
|
||||||
func GetSHA256Digest(r io.Reader) (string, int) {
|
func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
n, err := io.Copy(h, r)
|
n, err := io.Copy(h, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to check if a blob already exists in the Docker registry
|
// Function to check if a blob already exists in the Docker registry
|
||||||
|
|
|
@ -4,9 +4,9 @@ import "testing"
|
||||||
|
|
||||||
func TestParseModelPath(t *testing.T) {
|
func TestParseModelPath(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
arg string
|
arg string
|
||||||
want ModelPath
|
want ModelPath
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"full path https",
|
"full path https",
|
||||||
|
|
|
@ -15,8 +15,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
redirectChunkSize = 1024 * 1024 * 1024
|
redirectChunkSize int64 = 1024 * 1024 * 1024
|
||||||
regularChunkSize = 95 * 1024 * 1024
|
regularChunkSize int64 = 95 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
|
func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (*url.URL, int64, error) {
|
||||||
|
@ -48,7 +48,7 @@ func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *Regis
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return locationURL, int64(chunkSize), nil
|
return locationURL, chunkSize, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSize int64, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||||
|
@ -73,10 +73,10 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
|
||||||
fn: fn,
|
fn: fn,
|
||||||
}
|
}
|
||||||
|
|
||||||
for offset := int64(0); offset < int64(layer.Size); {
|
for offset := int64(0); offset < layer.Size; {
|
||||||
chunk := int64(layer.Size) - offset
|
chunk := layer.Size - offset
|
||||||
if chunk > int64(chunkSize) {
|
if chunk > chunkSize {
|
||||||
chunk = int64(chunkSize)
|
chunk = chunkSize
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
|
resp, err := uploadBlobChunk(ctx, http.MethodPatch, requestURL, f, offset, chunk, regOpts, &pw)
|
||||||
|
@ -85,7 +85,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
|
||||||
Status: fmt.Sprintf("error uploading chunk: %v", err),
|
Status: fmt.Sprintf("error uploading chunk: %v", err),
|
||||||
Digest: layer.Digest,
|
Digest: layer.Digest,
|
||||||
Total: layer.Size,
|
Total: layer.Size,
|
||||||
Completed: int(offset),
|
Completed: offset,
|
||||||
})
|
})
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
@ -127,7 +127,7 @@ func uploadBlob(ctx context.Context, requestURL *url.URL, layer *Layer, chunkSiz
|
||||||
}
|
}
|
||||||
|
|
||||||
func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
|
func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r io.ReaderAt, offset, limit int64, opts *RegistryOptions, pw *ProgressWriter) (*http.Response, error) {
|
||||||
sectionReader := io.NewSectionReader(r, int64(offset), limit)
|
sectionReader := io.NewSectionReader(r, offset, limit)
|
||||||
|
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
headers.Set("Content-Type", "application/octet-stream")
|
headers.Set("Content-Type", "application/octet-stream")
|
||||||
|
@ -152,7 +152,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pw.completed = int(offset)
|
pw.completed = offset
|
||||||
if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
|
if _, err := uploadBlobChunk(ctx, http.MethodPut, location, r, offset, limit, nil, pw); err != nil {
|
||||||
// retry
|
// retry
|
||||||
log.Printf("retrying redirected upload: %v", err)
|
log.Printf("retrying redirected upload: %v", err)
|
||||||
|
@ -170,7 +170,7 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
|
||||||
|
|
||||||
opts.Token = token
|
opts.Token = token
|
||||||
|
|
||||||
pw.completed = int(offset)
|
pw.completed = offset
|
||||||
sectionReader = io.NewSectionReader(r, offset, limit)
|
sectionReader = io.NewSectionReader(r, offset, limit)
|
||||||
continue
|
continue
|
||||||
case resp.StatusCode >= http.StatusBadRequest:
|
case resp.StatusCode >= http.StatusBadRequest:
|
||||||
|
@ -187,19 +187,19 @@ func uploadBlobChunk(ctx context.Context, method string, requestURL *url.URL, r
|
||||||
type ProgressWriter struct {
|
type ProgressWriter struct {
|
||||||
status string
|
status string
|
||||||
digest string
|
digest string
|
||||||
bucket int
|
bucket int64
|
||||||
completed int
|
completed int64
|
||||||
total int
|
total int64
|
||||||
fn func(api.ProgressResponse)
|
fn func(api.ProgressResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pw *ProgressWriter) Write(b []byte) (int, error) {
|
func (pw *ProgressWriter) Write(b []byte) (int, error) {
|
||||||
n := len(b)
|
n := len(b)
|
||||||
pw.bucket += n
|
pw.bucket += int64(n)
|
||||||
pw.completed += n
|
|
||||||
|
|
||||||
// throttle status updates to not spam the client
|
// throttle status updates to not spam the client
|
||||||
if pw.bucket >= 1024*1024 || pw.completed >= pw.total {
|
if pw.bucket >= 1024*1024 || pw.completed+pw.bucket >= pw.total {
|
||||||
|
pw.completed += pw.bucket
|
||||||
pw.fn(api.ProgressResponse{
|
pw.fn(api.ProgressResponse{
|
||||||
Status: pw.status,
|
Status: pw.status,
|
||||||
Digest: pw.digest,
|
Digest: pw.digest,
|
||||||
|
|
Loading…
Reference in a new issue