diff --git a/api/client.go b/api/client.go index ccbcbf6b..29ab2698 100644 --- a/api/client.go +++ b/api/client.go @@ -10,6 +10,20 @@ import ( "net/url" ) +type StatusError struct { + StatusCode int + Status string + Message string +} + +func (e StatusError) Error() string { + if e.Message != "" { + return fmt.Sprintf("%s: %s", e.Status, e.Message) + } + + return e.Status +} + type Client struct { base url.URL } @@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client { } } -func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error { +func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { var buf *bytes.Buffer if data != nil { bts, err := json.Marshal(data) @@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call scanner := bufio.NewScanner(response.Body) for scanner.Scan() { var errorResponse struct { - Error string `json:"error"` + Error string `json:"error,omitempty"` } bts := scanner.Bytes() @@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call return fmt.Errorf("unmarshal: %w", err) } - if len(errorResponse.Error) > 0 { - return fmt.Errorf("stream: %s", errorResponse.Error) + if response.StatusCode >= 400 { + return StatusError{ + StatusCode: response.StatusCode, + Status: response.Status, + Message: errorResponse.Error, + } } - if err := callback(bts); err != nil { + if err := fn(bts); err != nil { return err } } diff --git a/cmd/cmd.go b/cmd/cmd.go index 8421b8f5..ca924ae9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "net/http" "os" "path" "strings" @@ -34,7 +35,14 @@ func RunRun(cmd *cobra.Command, args []string) error { switch { case errors.Is(err, os.ErrNotExist): if err := pull(args[0]); err != nil { - return err + var apiStatusError api.StatusError + if !errors.As(err, &apiStatusError) { + return err + } + + if apiStatusError.StatusCode != http.StatusBadGateway { + return err + } } case err != nil: return err @@ -50,11 +58,12 @@ func pull(model string) error { context.Background(), &api.PullRequest{Model: model}, func(progress api.PullProgress) error { - if bar == nil && progress.Percent == 100 { - // already downloaded - return nil - } if bar == nil { + if progress.Percent == 100 { + // already downloaded + return nil + } + bar = progressbar.DefaultBytes(progress.Total) } diff --git a/go.mod b/go.mod index 8beb32bd..c2e15346 100644 --- a/go.mod +++ b/go.mod @@ -39,7 +39,6 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sync v0.3.0 golang.org/x/sys v0.10.0 // indirect golang.org/x/term v0.10.0 golang.org/x/text v0.10.0 // indirect diff --git a/go.sum b/go.sum index 9189b115..2adee49d 100644 --- a/go.sum +++ b/go.sum @@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/server/models.go b/server/models.go index 813cccc9..fd689ed6 100644 --- a/server/models.go +++ b/server/models.go @@ -2,14 +2,13 @@ package server import ( "encoding/json" + "errors" "fmt" "io" "net/http" "os" "path" "strconv" - - "github.com/jmorganca/ollama/api" ) const directoryURL = "https://ollama.ai/api/models" @@ -36,12 +35,12 @@ func (m *Model) FullName() string { return path.Join(home, ".ollama", "models", m.Name+".bin") } -func pull(model string, progressCh chan<- api.PullProgress) error { - remote, err := getRemote(model) - if err != nil { - return fmt.Errorf("failed to pull model: %w", err) - } - return saveModel(remote, progressCh) +func (m *Model) TempFile() string { + fullName := m.FullName() + return path.Join( + path.Dir(fullName), + fmt.Sprintf(".%s.part", path.Base(fullName)), + ) } func getRemote(model string) (*Model, error) { @@ -68,7 +67,7 @@ func getRemote(model string) (*Model, error) { return nil, fmt.Errorf("model not found in directory: %s", model) } -func saveModel(model *Model, progressCh chan<- api.PullProgress) error { +func saveModel(model *Model, fn func(total, completed int64)) error { // this models cache directory is created by the server on startup client := &http.Client{} @@ -76,41 +75,45 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { if err != nil { return fmt.Errorf("failed to download model: %w", err) } - // check for resume - alreadyDownloaded := int64(0) - fileInfo, err := os.Stat(model.FullName()) - if err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("failed to check resume model file: %w", err) - } - // file doesn't exist, create it now - } else { - alreadyDownloaded = fileInfo.Size() - req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded)) + + // check if completed file exists + fi, err := os.Stat(model.FullName()) + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + fn(fi.Size(), fi.Size()) + return nil } + var size int64 + + // completed file doesn't exist, check partial file + fi, err = os.Stat(model.TempFile()) + switch { + case errors.Is(err, os.ErrNotExist): + // noop, file doesn't exist so create it + case err != nil: + return fmt.Errorf("stat: %w", err) + default: + size = fi.Size() + } + + req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size)) + resp, err := client.Do(req) if err != nil { return fmt.Errorf("failed to download model: %w", err) } - defer resp.Body.Close() - if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { - // already downloaded - progressCh <- api.PullProgress{ - Total: alreadyDownloaded, - Completed: alreadyDownloaded, - Percent: 100, - } - return nil - } - - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + if resp.StatusCode >= 400 { return fmt.Errorf("failed to download model: %s", resp.Status) } - out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { panic(err) } @@ -118,37 +121,23 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - buf := make([]byte, 1024) - totalBytes := alreadyDownloaded - totalSize += alreadyDownloaded + totalBytes := size + totalSize += size for { - n, err := resp.Body.Read(buf) - if err != nil && err != io.EOF { + n, err := io.CopyN(out, resp.Body, 8192) + if err != nil && !errors.Is(err, io.EOF) { return err } + if n == 0 { break } - if _, err := out.Write(buf[:n]); err != nil { - return err - } - totalBytes += int64(n) - - // send progress updates - progressCh <- api.PullProgress{ - Total: totalSize, - Completed: totalBytes, - Percent: float64(totalBytes) / float64(totalSize) * 100, - } + totalBytes += n + fn(totalSize, totalBytes) } - progressCh <- api.PullProgress{ - Total: totalSize, - Completed: totalSize, - Percent: 100, - } - - return nil + fn(totalSize, totalSize) + return os.Rename(model.TempFile(), model.FullName()) } diff --git a/server/routes.go b/server/routes.go index 47551f15..ef19f3c2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -16,7 +16,6 @@ import ( "github.com/gin-gonic/gin" "github.com/lithammer/fuzzysearch/fuzzy" - "golang.org/x/sync/errgroup" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/llama" @@ -56,12 +55,8 @@ func generate(c *gin.Context) { req.Model = path.Join(cacheDir(), "models", req.Model+".bin") } - llm, err := llama.New(req.Model, req.Options) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer llm.Close() + ch := make(chan any) + go stream(c, ch) templateNames := make([]string, 0, len(templates.Templates())) for _, template := range templates.Templates() { @@ -79,39 +74,49 @@ func generate(c *gin.Context) { req.Prompt = sb.String() } - ch := make(chan string) - g, _ := errgroup.WithContext(c.Request.Context()) - g.Go(func() error { - defer close(ch) - return llm.Predict(req.Prompt, func(s string) { - ch <- s - }) - }) + llm, err := llama.New(req.Model, req.Options) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + defer llm.Close() - g.Go(func() error { - c.Stream(func(w io.Writer) bool { - s, ok := <-ch - if !ok { - return false - } + fn := func(s string) { + ch <- api.GenerateResponse{Response: s} + } - bts, err := json.Marshal(api.GenerateResponse{Response: s}) - if err != nil { - return false - } + if err := llm.Predict(req.Prompt, fn); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } - bts = append(bts, '\n') - if _, err := w.Write(bts); err != nil { - return false - } +} - return true - }) +func pull(c *gin.Context) { + var req api.PullRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } - return nil - }) + remote, err := getRemote(req.Model) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } - if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) { + ch := make(chan any) + go stream(c, ch) + + fn := func(total, completed int64) { + ch <- api.PullProgress{ + Total: total, + Completed: completed, + Percent: float64(total) / float64(completed) * 100, + } + } + + if err := saveModel(remote, fn); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -124,47 +129,7 @@ func Serve(ln net.Listener) error { c.String(http.StatusOK, "Ollama is running") }) - r.POST("api/pull", func(c *gin.Context) { - var req api.PullRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - progressCh := make(chan api.PullProgress) - go func() { - defer close(progressCh) - if err := pull(req.Model, progressCh); err != nil { - var opError *net.OpError - if errors.As(err, &opError) { - c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - }() - - c.Stream(func(w io.Writer) bool { - progress, ok := <-progressCh - if !ok { - return false - } - - bts, err := json.Marshal(progress) - if err != nil { - return false - } - - bts = append(bts, '\n') - if _, err := w.Write(bts); err != nil { - return false - } - - return true - }) - }) - + r.POST("api/pull", pull) r.POST("/api/generate", generate) log.Printf("Listening on %s", ln.Addr()) @@ -186,3 +151,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i return } + +func stream(c *gin.Context, ch chan any) { + c.Stream(func(w io.Writer) bool { + val, ok := <-ch + if !ok { + return false + } + + bts, err := json.Marshal(val) + if err != nil { + return false + } + + bts = append(bts, '\n') + if _, err := w.Write(bts); err != nil { + return false + } + + return true + }) +}