fix race
block on write, which only returns once the channel is closed. previously the handler could return while the stream was still writing, so the client could receive unexpected responses (since the request had already been handled) or, in the worst case, hit a nil pointer dereference as the stream tried to flush a nil writer
parent 965f9ad033
commit 5ade3db040
1 changed file with 25 additions and 30 deletions
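The shape of the race and of the fix, reduced to a minimal standalone sketch before the diff itself (the route, port, and SSE encoding here are illustrative assumptions; streamResponse and the channel discipline mirror the diff below): before, the handler spawned the stream consumer with go stream(c, ch) and could return while writes were still in flight; after, the producer runs in the goroutine and signals completion by closing the channel, and the handler blocks in streamResponse until the stream is fully written.

package main

import (
	"io"

	"github.com/gin-gonic/gin"
)

// streamResponse drains ch and writes each value to the response. It only
// returns once ch is closed, so a handler calling it cannot finish before
// the stream does. This mirrors the function renamed in the diff below.
func streamResponse(c *gin.Context, ch chan any) {
	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			return false // channel closed: end the stream
		}
		c.SSEvent("message", val) // illustrative encoding only
		return true
	})
}

func handler(c *gin.Context) {
	ch := make(chan any)

	// racy (old) shape:
	//   go stream(c, ch)    // consumer races the handler's return;
	//                       // the handler finishes when production ends,
	//                       // not when writing ends

	// fixed shape: produce concurrently, close ch when done, and block
	// on the consumer so the handler returns only after the last write.
	go func() {
		defer close(ch)
		for i := 0; i < 3; i++ {
			ch <- gin.H{"n": i}
		}
	}()
	streamResponse(c, ch)
}

func main() {
	r := gin.Default()
	r.GET("/stream", handler)
	r.Run(":8080") // illustrative address
}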
@@ -58,9 +58,6 @@ func generate(c *gin.Context) {
 		req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
 	}
 
-	ch := make(chan any)
-	go stream(c, ch)
-
 	templateNames := make([]string, 0, len(templates.Templates()))
 	for _, template := range templates.Templates() {
 		templateNames = append(templateNames, template.Name())
@@ -84,21 +81,21 @@ func generate(c *gin.Context) {
 	}
 	defer llm.Close()
 
-	fn := func(r api.GenerateResponse) {
-		r.Model = req.Model
-		r.CreatedAt = time.Now().UTC()
-		if r.Done {
-			r.TotalDuration = time.Since(start)
-		}
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
+			r.Model = req.Model
+			r.CreatedAt = time.Now().UTC()
+			if r.Done {
+				r.TotalDuration = time.Since(start)
+			}
 
-		ch <- r
-	}
-
-	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+			ch <- r
+		})
+	}()
 
+	streamResponse(c, ch)
 }
 
 func pull(c *gin.Context) {
@@ -133,20 +130,18 @@ func pull(c *gin.Context) {
 	}
 
 	ch := make(chan any)
-	go stream(c, ch)
+	go func() {
+		defer close(ch)
+		saveModel(remote, func(total, completed int64) {
+			ch <- api.PullProgress{
+				Total:     total,
+				Completed: completed,
+				Percent:   float64(completed) / float64(total) * 100,
+			}
+		})
+	}()
 
-	fn := func(total, completed int64) {
-		ch <- api.PullProgress{
-			Total:     total,
-			Completed: completed,
-			Percent:   float64(completed) / float64(total) * 100,
-		}
-	}
-
-	if err := saveModel(remote, fn); err != nil {
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	}
+	streamResponse(c, ch)
 }
 
 func Serve(ln net.Listener) error {
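Both hunks apply the same transformation: a callback-driven API (llm.Predict, saveModel) becomes a channel producer that closes the channel when the underlying call returns, giving the blocking consumer a definite end-of-stream signal. A generic rendering of that pattern (the helper name and its use of generics are illustrative, not part of the commit):

// callbackToChannel runs a callback-driven producer in a goroutine,
// forwards each emitted value into a channel, and closes the channel
// when the producer returns. Closing is the end-of-stream signal the
// blocking consumer waits for.
func callbackToChannel[T any](produce func(emit func(T))) <-chan any {
	ch := make(chan any)
	go func() {
		defer close(ch)
		produce(func(v T) { ch <- v })
	}()
	return ch
}

pull, for example, would then reduce to draining callbackToChannel over saveModel's progress callback.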
@@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
 	return
 }
 
-func stream(c *gin.Context, ch chan any) {
+func streamResponse(c *gin.Context, ch chan any) {
 	c.Stream(func(w io.Writer) bool {
 		val, ok := <-ch
 		if !ok {
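The final hunk cuts off inside streamResponse. Its contract, implied by the producers above, is to drain the channel and end the stream once the channel is closed. A minimal sketch of such a consumer, assuming newline-delimited JSON as the wire format (the encoding is an assumption; the hunk ends before showing it):

import (
	"encoding/json"
	"io"

	"github.com/gin-gonic/gin"
)

func streamResponse(c *gin.Context, ch chan any) {
	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			// the producer closed the channel: end the stream, and with
			// it the handler's blocking call
			return false
		}

		bts, err := json.Marshal(val)
		if err != nil {
			return false
		}

		// assumed wire format: one JSON object per line
		if _, err := w.Write(append(bts, '\n')); err != nil {
			return false
		}

		return true
	})
}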