From 5ade3db040b96b4c618e080aaf273aeec9f8edd1 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 14 Jul 2023 14:15:53 -0700 Subject: [PATCH] fix race block on write, which only returns when the channel is closed. this is contrary to the previous arrangement, where the handler may return before the stream has finished writing. it can lead to the client receiving unexpected responses (since the request has already been handled) or, in the worst case, a nil-pointer dereference as the stream tries to flush a nil writer --- server/routes.go | 55 ++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/server/routes.go b/server/routes.go index ab651f42..29c9dacf 100644 --- a/server/routes.go +++ b/server/routes.go @@ -58,9 +58,6 @@ func generate(c *gin.Context) { req.Model = path.Join(cacheDir(), "models", req.Model+".bin") } - ch := make(chan any) - go stream(c, ch) - templateNames := make([]string, 0, len(templates.Templates())) for _, template := range templates.Templates() { templateNames = append(templateNames, template.Name()) @@ -84,21 +81,21 @@ func generate(c *gin.Context) { } defer llm.Close() - fn := func(r api.GenerateResponse) { - r.Model = req.Model - r.CreatedAt = time.Now().UTC() - if r.Done { - r.TotalDuration = time.Since(start) - } + ch := make(chan any) + go func() { + defer close(ch) + llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) { + r.Model = req.Model + r.CreatedAt = time.Now().UTC() + if r.Done { + r.TotalDuration = time.Since(start) + } - ch <- r - } - - if err := llm.Predict(req.Context, req.Prompt, fn); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + ch <- r + }) + }() + streamResponse(c, ch) } func pull(c *gin.Context) { @@ -133,20 +130,18 @@ func pull(c *gin.Context) { } ch := make(chan any) - go stream(c, ch) + go func() { + defer close(ch) + saveModel(remote, func(total, completed int64) { + ch <- api.PullProgress{ + Total: total, 
+ Completed: completed, + Percent: float64(completed) / float64(total) * 100, + } + }) + }() - fn := func(total, completed int64) { - ch <- api.PullProgress{ - Total: total, - Completed: completed, - Percent: float64(completed) / float64(total) * 100, - } - } - - if err := saveModel(remote, fn); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + streamResponse(c, ch) } func Serve(ln net.Listener) error { @@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i return } -func stream(c *gin.Context, ch chan any) { +func streamResponse(c *gin.Context, ch chan any) { c.Stream(func(w io.Writer) bool { val, ok := <-ch if !ok {