diff --git a/server/routes.go b/server/routes.go index ab651f42..29c9dacf 100644 --- a/server/routes.go +++ b/server/routes.go @@ -58,9 +58,6 @@ func generate(c *gin.Context) { req.Model = path.Join(cacheDir(), "models", req.Model+".bin") } - ch := make(chan any) - go stream(c, ch) - templateNames := make([]string, 0, len(templates.Templates())) for _, template := range templates.Templates() { templateNames = append(templateNames, template.Name()) @@ -84,21 +81,21 @@ func generate(c *gin.Context) { } defer llm.Close() - fn := func(r api.GenerateResponse) { - r.Model = req.Model - r.CreatedAt = time.Now().UTC() - if r.Done { - r.TotalDuration = time.Since(start) - } + ch := make(chan any) + go func() { + defer close(ch) + llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) { + r.Model = req.Model + r.CreatedAt = time.Now().UTC() + if r.Done { + r.TotalDuration = time.Since(start) + } - ch <- r - } - - if err := llm.Predict(req.Context, req.Prompt, fn); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + ch <- r + }) + }() + streamResponse(c, ch) } func pull(c *gin.Context) { @@ -133,20 +130,18 @@ func pull(c *gin.Context) { } ch := make(chan any) - go stream(c, ch) + go func() { + defer close(ch) + saveModel(remote, func(total, completed int64) { + ch <- api.PullProgress{ + Total: total, + Completed: completed, + Percent: float64(completed) / float64(total) * 100, + } + }) + }() - fn := func(total, completed int64) { - ch <- api.PullProgress{ - Total: total, - Completed: completed, - Percent: float64(completed) / float64(total) * 100, - } - } - - if err := saveModel(remote, fn); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + streamResponse(c, ch) } func Serve(ln net.Listener) error { @@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i return } -func stream(c *gin.Context, ch chan any) { +func streamResponse(c *gin.Context, ch chan any) { c.Stream(func(w io.Writer) bool { val, ok := <-ch if !ok {