common stream producer

This commit is contained in:
Michael Yang 2023-07-11 11:54:22 -07:00
parent 62620914e9
commit 2a66a1164a
2 changed files with 61 additions and 85 deletions

View file

@ -8,8 +8,6 @@ import (
"os" "os"
"path" "path"
"strconv" "strconv"
"github.com/jmorganca/ollama/api"
) )
const directoryURL = "https://ollama.ai/api/models" const directoryURL = "https://ollama.ai/api/models"
@ -36,14 +34,6 @@ func (m *Model) FullName() string {
return path.Join(home, ".ollama", "models", m.Name+".bin") return path.Join(home, ".ollama", "models", m.Name+".bin")
} }
func pull(model string, progressCh chan<- api.PullProgress) error {
remote, err := getRemote(model)
if err != nil {
return fmt.Errorf("failed to pull model: %w", err)
}
return saveModel(remote, progressCh)
}
func getRemote(model string) (*Model, error) { func getRemote(model string) (*Model, error) {
// resolve the model download from our directory // resolve the model download from our directory
resp, err := http.Get(directoryURL) resp, err := http.Get(directoryURL)
@ -68,7 +58,7 @@ func getRemote(model string) (*Model, error) {
return nil, fmt.Errorf("model not found in directory: %s", model) return nil, fmt.Errorf("model not found in directory: %s", model)
} }
func saveModel(model *Model, progressCh chan<- api.PullProgress) error { func saveModel(model *Model, fn func(total, completed int64)) error {
// this models cache directory is created by the server on startup // this models cache directory is created by the server on startup
client := &http.Client{} client := &http.Client{}
@ -98,11 +88,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
// already downloaded // already downloaded
progressCh <- api.PullProgress{ fn(alreadyDownloaded, alreadyDownloaded)
Total: alreadyDownloaded,
Completed: alreadyDownloaded,
Percent: 100,
}
return nil return nil
} }
@ -136,19 +122,9 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
totalBytes += int64(n) totalBytes += int64(n)
// send progress updates fn(totalSize, totalBytes)
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalBytes,
Percent: float64(totalBytes) / float64(totalSize) * 100,
}
}
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalSize,
Percent: 100,
} }
fn(totalSize, totalSize)
return nil return nil
} }

View file

@ -79,35 +79,54 @@ func generate(c *gin.Context) {
req.Prompt = sb.String() req.Prompt = sb.String()
} }
ch := make(chan string) ch := make(chan any)
g, _ := errgroup.WithContext(c.Request.Context()) g, _ := errgroup.WithContext(c.Request.Context())
g.Go(func() error { g.Go(func() error {
defer close(ch) defer close(ch)
return llm.Predict(req.Prompt, func(s string) { return llm.Predict(req.Prompt, func(s string) {
ch <- s ch <- api.GenerateResponse{Response: s}
}) })
}) })
g.Go(func() error { g.Go(func() error {
c.Stream(func(w io.Writer) bool { stream(c, ch)
s, ok := <-ch return nil
if !ok { })
return false
}
bts, err := json.Marshal(api.GenerateResponse{Response: s}) if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return false return
} }
}
bts = append(bts, '\n') func pull(c *gin.Context) {
if _, err := w.Write(bts); err != nil { var req api.PullRequest
return false if err := c.ShouldBindJSON(&req); err != nil {
} c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
return true remote, err := getRemote(req.Model)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
g, _ := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
defer close(ch)
return saveModel(remote, func(total, completed int64) {
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(total) / float64(completed) * 100,
}
}) })
})
g.Go(func() error {
stream(c, ch)
return nil return nil
}) })
@ -124,47 +143,7 @@ func Serve(ln net.Listener) error {
c.String(http.StatusOK, "Ollama is running") c.String(http.StatusOK, "Ollama is running")
}) })
r.POST("api/pull", func(c *gin.Context) { r.POST("api/pull", pull)
var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
progressCh := make(chan api.PullProgress)
go func() {
defer close(progressCh)
if err := pull(req.Model, progressCh); err != nil {
var opError *net.OpError
if errors.As(err, &opError) {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
}()
c.Stream(func(w io.Writer) bool {
progress, ok := <-progressCh
if !ok {
return false
}
bts, err := json.Marshal(progress)
if err != nil {
return false
}
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
return false
}
return true
})
})
r.POST("/api/generate", generate) r.POST("/api/generate", generate)
log.Printf("Listening on %s", ln.Addr()) log.Printf("Listening on %s", ln.Addr())
@ -186,3 +165,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
return return
} }
func stream(c *gin.Context, ch chan any) {
c.Stream(func(w io.Writer) bool {
val, ok := <-ch
if !ok {
return false
}
bts, err := json.Marshal(val)
if err != nil {
return false
}
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
return false
}
return true
})
}