rename partial file

This commit is contained in:
Michael Yang 2023-07-11 13:36:35 -07:00
parent e243329e2e
commit 948323fa78

View file

@ -2,6 +2,7 @@ package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -34,6 +35,14 @@ func (m *Model) FullName() string {
return path.Join(home, ".ollama", "models", m.Name+".bin")
}
func (m *Model) TempFile() string {
fullName := m.FullName()
return path.Join(
path.Dir(fullName),
fmt.Sprintf(".%s.part", path.Base(fullName)),
)
}
func getRemote(model string) (*Model, error) {
// resolve the model download from our directory
resp, err := http.Get(directoryURL)
@ -66,37 +75,45 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
if err != nil {
return fmt.Errorf("failed to download model: %w", err)
}
// check for resume
alreadyDownloaded := int64(0)
fileInfo, err := os.Stat(model.FullName())
if err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("failed to check resume model file: %w", err)
}
// file doesn't exist, create it now
} else {
alreadyDownloaded = fileInfo.Size()
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
// check if completed file exists
fi, err := os.Stat(model.FullName())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
fn(fi.Size(), fi.Size())
return nil
}
var size int64
// completed file doesn't exist, check partial file
fi, err = os.Stat(model.TempFile())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
}
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to download model: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
// already downloaded
fn(alreadyDownloaded, alreadyDownloaded)
return nil
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
if resp.StatusCode >= 400 {
return fmt.Errorf("failed to download model: %s", resp.Status)
}
out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil {
panic(err)
}
@ -104,27 +121,23 @@ func saveModel(model *Model, fn func(total, completed int64)) error {
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
buf := make([]byte, 1024)
totalBytes := alreadyDownloaded
totalSize += alreadyDownloaded
totalBytes := size
totalSize += size
for {
n, err := resp.Body.Read(buf)
if err != nil && err != io.EOF {
n, err := io.CopyN(out, resp.Body, 8192)
if err != nil && !errors.Is(err, io.EOF) {
return err
}
if n == 0 {
break
}
if _, err := out.Write(buf[:n]); err != nil {
return err
}
totalBytes += int64(n)
totalBytes += n
fn(totalSize, totalBytes)
}
fn(totalSize, totalSize)
return nil
return os.Rename(model.TempFile(), model.FullName())
}