From 2a66a1164a009f597f8931f155e18b05777c6602 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Tue, 11 Jul 2023 11:54:22 -0700
Subject: [PATCH] common stream producer

saveModel now reports progress through a func(total, completed int64)
callback instead of an api.PullProgress channel, so server/models.go no
longer imports the api package.

The api/pull route becomes a named pull handler. It resolves the model
with getRemote, runs saveModel inside an errgroup, and turns each
callback invocation into an api.PullProgress whose Percent is
completed / total * 100.

generate and pull both hand their channel to the new stream helper,
which JSON-encodes each received value on its own line until the channel
is closed or a marshal/write error occurs.

---
 server/models.go |  32 ++-----------
 server/routes.go | 114 +++++++++++++++++++++++------------------------
 2 files changed, 61 insertions(+), 85 deletions(-)

diff --git a/server/models.go b/server/models.go
index 813cccc9..496b2c45 100644
--- a/server/models.go
+++ b/server/models.go
@@ -8,8 +8,6 @@ import (
 	"os"
 	"path"
 	"strconv"
-
-	"github.com/jmorganca/ollama/api"
 )
 
 const directoryURL = "https://ollama.ai/api/models"
@@ -36,14 +34,6 @@ func (m *Model) FullName() string {
 	return path.Join(home, ".ollama", "models", m.Name+".bin")
 }
 
-func pull(model string, progressCh chan<- api.PullProgress) error {
-	remote, err := getRemote(model)
-	if err != nil {
-		return fmt.Errorf("failed to pull model: %w", err)
-	}
-	return saveModel(remote, progressCh)
-}
-
 func getRemote(model string) (*Model, error) {
 	// resolve the model download from our directory
 	resp, err := http.Get(directoryURL)
@@ -68,7 +58,7 @@ func getRemote(model string) (*Model, error) {
 	return nil, fmt.Errorf("model not found in directory: %s", model)
 }
 
-func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
+func saveModel(model *Model, fn func(total, completed int64)) error {
 	// this models cache directory is created by the server on startup
 
 	client := &http.Client{}
@@ -98,11 +88,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 
 	if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
 		// already downloaded
-		progressCh <- api.PullProgress{
-			Total:     alreadyDownloaded,
-			Completed: alreadyDownloaded,
-			Percent:   100,
-		}
+		fn(alreadyDownloaded, alreadyDownloaded)
 		return nil
 	}
 
@@ -136,19 +122,9 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
 
 		totalBytes += int64(n)
 
-		// send progress updates
-		progressCh <- api.PullProgress{
-			Total:     totalSize,
-			Completed: totalBytes,
-			Percent:   float64(totalBytes) / float64(totalSize) * 100,
-		}
-	}
-
-	progressCh <- api.PullProgress{
-		Total:     totalSize,
-		Completed: totalSize,
-		Percent:   100,
+		fn(totalSize, totalBytes)
 	}
+	fn(totalSize, totalSize)
 
 	return nil
 }
diff --git a/server/routes.go b/server/routes.go
index 47551f15..94894fdb 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -79,35 +79,54 @@ func generate(c *gin.Context) {
 		req.Prompt = sb.String()
 	}
 
-	ch := make(chan string)
+	ch := make(chan any)
 	g, _ := errgroup.WithContext(c.Request.Context())
 	g.Go(func() error {
 		defer close(ch)
 		return llm.Predict(req.Prompt, func(s string) {
-			ch <- s
+			ch <- api.GenerateResponse{Response: s}
 		})
 	})
 
 	g.Go(func() error {
-		c.Stream(func(w io.Writer) bool {
-			s, ok := <-ch
-			if !ok {
-				return false
-			}
+		stream(c, ch)
+		return nil
+	})
 
-			bts, err := json.Marshal(api.GenerateResponse{Response: s})
-			if err != nil {
-				return false
-			}
+	if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+}
 
-			bts = append(bts, '\n')
-			if _, err := w.Write(bts); err != nil {
-				return false
-			}
+func pull(c *gin.Context) {
+	var req api.PullRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
 
-			return true
+	remote, err := getRemote(req.Model)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	ch := make(chan any)
+	g, _ := errgroup.WithContext(c.Request.Context())
+	g.Go(func() error {
+		defer close(ch)
+		return saveModel(remote, func(total, completed int64) {
+			ch <- api.PullProgress{
+				Total:     total,
+				Completed: completed,
+				Percent:   float64(completed) / float64(total) * 100,
+			}
 		})
+	})
 
+	g.Go(func() error {
+		stream(c, ch)
 		return nil
 	})
 
@@ -124,47 +143,7 @@ func Serve(ln net.Listener) error {
 		c.String(http.StatusOK, "Ollama is running")
 	})
 
-	r.POST("api/pull", func(c *gin.Context) {
-		var req api.PullRequest
-		if err := c.ShouldBindJSON(&req); err != nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-
-		progressCh := make(chan api.PullProgress)
-		go func() {
-			defer close(progressCh)
-			if err := pull(req.Model, progressCh); err != nil {
-				var opError *net.OpError
-				if errors.As(err, &opError) {
-					c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
-					return
-				}
-				c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-				return
-			}
-		}()
-
-		c.Stream(func(w io.Writer) bool {
-			progress, ok := <-progressCh
-			if !ok {
-				return false
-			}
-
-			bts, err := json.Marshal(progress)
-			if err != nil {
-				return false
-			}
-
-			bts = append(bts, '\n')
-			if _, err := w.Write(bts); err != nil {
-				return false
-			}
-
-			return true
-		})
-	})
-
+	r.POST("api/pull", pull)
 	r.POST("/api/generate", generate)
 
 	log.Printf("Listening on %s", ln.Addr())
@@ -186,3 +165,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
 
 	return
 }
+
+func stream(c *gin.Context, ch chan any) {
+	c.Stream(func(w io.Writer) bool {
+		val, ok := <-ch
+		if !ok {
+			return false
+		}
+
+		bts, err := json.Marshal(val)
+		if err != nil {
+			return false
+		}
+
+		bts = append(bts, '\n')
+		if _, err := w.Write(bts); err != nil {
+			return false
+		}
+
+		return true
+	})
+}