Merge pull request #81 from jmorganca/fix-race-2

fix race
This commit is contained in:
Michael Yang 2023-07-14 15:12:01 -07:00 committed by GitHub
commit 567e74e7d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -58,9 +58,6 @@ func generate(c *gin.Context) {
req.Model = path.Join(cacheDir(), "models", req.Model+".bin") req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
} }
ch := make(chan any)
go stream(c, ch)
templateNames := make([]string, 0, len(templates.Templates())) templateNames := make([]string, 0, len(templates.Templates()))
for _, template := range templates.Templates() { for _, template := range templates.Templates() {
templateNames = append(templateNames, template.Name()) templateNames = append(templateNames, template.Name())
@ -84,21 +81,21 @@ func generate(c *gin.Context) {
} }
defer llm.Close() defer llm.Close()
fn := func(r api.GenerateResponse) { ch := make(chan any)
r.Model = req.Model go func() {
r.CreatedAt = time.Now().UTC() defer close(ch)
if r.Done { llm.Predict(req.Context, req.Prompt, func(r api.GenerateResponse) {
r.TotalDuration = time.Since(start) r.Model = req.Model
} r.CreatedAt = time.Now().UTC()
if r.Done {
r.TotalDuration = time.Since(start)
}
ch <- r ch <- r
} })
}()
if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
streamResponse(c, ch)
} }
func pull(c *gin.Context) { func pull(c *gin.Context) {
@ -133,20 +130,18 @@ func pull(c *gin.Context) {
} }
ch := make(chan any) ch := make(chan any)
go stream(c, ch) go func() {
defer close(ch)
saveModel(remote, func(total, completed int64) {
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(completed) / float64(total) * 100,
}
})
}()
fn := func(total, completed int64) { streamResponse(c, ch)
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(completed) / float64(total) * 100,
}
}
if err := saveModel(remote, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
func Serve(ln net.Listener) error { func Serve(ln net.Listener) error {
@ -179,7 +174,7 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
return return
} }
func stream(c *gin.Context, ch chan any) { func streamResponse(c *gin.Context, ch chan any) {
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
val, ok := <-ch val, ok := <-ch
if !ok { if !ok {