common stream producer
This commit is contained in:
parent
62620914e9
commit
2a66a1164a
2 changed files with 61 additions and 85 deletions
|
@ -8,8 +8,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const directoryURL = "https://ollama.ai/api/models"
|
const directoryURL = "https://ollama.ai/api/models"
|
||||||
|
@ -36,14 +34,6 @@ func (m *Model) FullName() string {
|
||||||
return path.Join(home, ".ollama", "models", m.Name+".bin")
|
return path.Join(home, ".ollama", "models", m.Name+".bin")
|
||||||
}
|
}
|
||||||
|
|
||||||
func pull(model string, progressCh chan<- api.PullProgress) error {
|
|
||||||
remote, err := getRemote(model)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to pull model: %w", err)
|
|
||||||
}
|
|
||||||
return saveModel(remote, progressCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRemote(model string) (*Model, error) {
|
func getRemote(model string) (*Model, error) {
|
||||||
// resolve the model download from our directory
|
// resolve the model download from our directory
|
||||||
resp, err := http.Get(directoryURL)
|
resp, err := http.Get(directoryURL)
|
||||||
|
@ -68,7 +58,7 @@ func getRemote(model string) (*Model, error) {
|
||||||
return nil, fmt.Errorf("model not found in directory: %s", model)
|
return nil, fmt.Errorf("model not found in directory: %s", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
func saveModel(model *Model, fn func(total, completed int64)) error {
|
||||||
// this models cache directory is created by the server on startup
|
// this models cache directory is created by the server on startup
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
@ -98,11 +88,7 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
||||||
// already downloaded
|
// already downloaded
|
||||||
progressCh <- api.PullProgress{
|
fn(alreadyDownloaded, alreadyDownloaded)
|
||||||
Total: alreadyDownloaded,
|
|
||||||
Completed: alreadyDownloaded,
|
|
||||||
Percent: 100,
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,19 +122,9 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
|
|
||||||
totalBytes += int64(n)
|
totalBytes += int64(n)
|
||||||
|
|
||||||
// send progress updates
|
fn(totalSize, totalBytes)
|
||||||
progressCh <- api.PullProgress{
|
|
||||||
Total: totalSize,
|
|
||||||
Completed: totalBytes,
|
|
||||||
Percent: float64(totalBytes) / float64(totalSize) * 100,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
progressCh <- api.PullProgress{
|
|
||||||
Total: totalSize,
|
|
||||||
Completed: totalSize,
|
|
||||||
Percent: 100,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn(totalSize, totalSize)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
114
server/routes.go
114
server/routes.go
|
@ -79,35 +79,54 @@ func generate(c *gin.Context) {
|
||||||
req.Prompt = sb.String()
|
req.Prompt = sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
ch := make(chan string)
|
ch := make(chan any)
|
||||||
g, _ := errgroup.WithContext(c.Request.Context())
|
g, _ := errgroup.WithContext(c.Request.Context())
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
return llm.Predict(req.Prompt, func(s string) {
|
return llm.Predict(req.Prompt, func(s string) {
|
||||||
ch <- s
|
ch <- api.GenerateResponse{Response: s}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
c.Stream(func(w io.Writer) bool {
|
stream(c, ch)
|
||||||
s, ok := <-ch
|
return nil
|
||||||
if !ok {
|
})
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
bts, err := json.Marshal(api.GenerateResponse{Response: s})
|
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
|
||||||
if err != nil {
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bts = append(bts, '\n')
|
func pull(c *gin.Context) {
|
||||||
if _, err := w.Write(bts); err != nil {
|
var req api.PullRequest
|
||||||
return false
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
}
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
remote, err := getRemote(req.Model)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan any)
|
||||||
|
g, _ := errgroup.WithContext(c.Request.Context())
|
||||||
|
g.Go(func() error {
|
||||||
|
defer close(ch)
|
||||||
|
return saveModel(remote, func(total, completed int64) {
|
||||||
|
ch <- api.PullProgress{
|
||||||
|
Total: total,
|
||||||
|
Completed: completed,
|
||||||
|
Percent: float64(total) / float64(completed) * 100,
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
g.Go(func() error {
|
||||||
|
stream(c, ch)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -124,47 +143,7 @@ func Serve(ln net.Listener) error {
|
||||||
c.String(http.StatusOK, "Ollama is running")
|
c.String(http.StatusOK, "Ollama is running")
|
||||||
})
|
})
|
||||||
|
|
||||||
r.POST("api/pull", func(c *gin.Context) {
|
r.POST("api/pull", pull)
|
||||||
var req api.PullRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
progressCh := make(chan api.PullProgress)
|
|
||||||
go func() {
|
|
||||||
defer close(progressCh)
|
|
||||||
if err := pull(req.Model, progressCh); err != nil {
|
|
||||||
var opError *net.OpError
|
|
||||||
if errors.As(err, &opError) {
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
progress, ok := <-progressCh
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
bts, err := json.Marshal(progress)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
bts = append(bts, '\n')
|
|
||||||
if _, err := w.Write(bts); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
r.POST("/api/generate", generate)
|
r.POST("/api/generate", generate)
|
||||||
|
|
||||||
log.Printf("Listening on %s", ln.Addr())
|
log.Printf("Listening on %s", ln.Addr())
|
||||||
|
@ -186,3 +165,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func stream(c *gin.Context, ch chan any) {
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
val, ok := <-ch
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bts, err := json.Marshal(val)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bts = append(bts, '\n')
|
||||||
|
if _, err := w.Write(bts); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue