no errgroup

This commit is contained in:
Michael Yang 2023-07-11 14:57:17 -07:00
parent 948323fa78
commit a806b03f62
3 changed files with 24 additions and 41 deletions

1
go.mod
View file

@ -39,7 +39,6 @@ require (
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.10.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sync v0.3.0
golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0
golang.org/x/text v0.10.0 // indirect

2
go.sum
View file

@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View file

@ -16,7 +16,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/lithammer/fuzzysearch/fuzzy"
"golang.org/x/sync/errgroup"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
@ -56,12 +55,8 @@ func generate(c *gin.Context) {
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
}
llm, err := llama.New(req.Model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer llm.Close()
ch := make(chan any)
go stream(c, ch)
templateNames := make([]string, 0, len(templates.Templates()))
for _, template := range templates.Templates() {
@ -79,24 +74,22 @@ func generate(c *gin.Context) {
req.Prompt = sb.String()
}
ch := make(chan any)
g, _ := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
defer close(ch)
return llm.Predict(req.Prompt, func(s string) {
ch <- api.GenerateResponse{Response: s}
})
})
g.Go(func() error {
stream(c, ch)
return nil
})
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
llm, err := llama.New(req.Model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer llm.Close()
fn := func(s string) {
ch <- api.GenerateResponse{Response: s}
}
if err := llm.Predict(req.Prompt, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func pull(c *gin.Context) {
@ -113,24 +106,17 @@ func pull(c *gin.Context) {
}
ch := make(chan any)
g, _ := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
defer close(ch)
return saveModel(remote, func(total, completed int64) {
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(total) / float64(completed) * 100,
}
})
})
go stream(c, ch)
g.Go(func() error {
stream(c, ch)
return nil
})
fn := func(total, completed int64) {
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(total) / float64(completed) * 100,
}
}
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
if err := saveModel(remote, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}