Merge pull request #70 from jmorganca/offline-fixes

offline fixes
This commit is contained in:
Michael Yang 2023-07-11 15:50:19 -07:00 committed by GitHub
commit 7226980fb6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 143 additions and 144 deletions

View file

@ -10,6 +10,20 @@ import (
"net/url" "net/url"
) )
type StatusError struct {
StatusCode int
Status string
Message string
}
func (e StatusError) Error() string {
if e.Message != "" {
return fmt.Sprintf("%s: %s", e.Status, e.Message)
}
return e.Status
}
type Client struct { type Client struct {
base url.URL base url.URL
} }
@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client {
} }
} }
func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error { func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer var buf *bytes.Buffer
if data != nil { if data != nil {
bts, err := json.Marshal(data) bts, err := json.Marshal(data)
@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
for scanner.Scan() { for scanner.Scan() {
var errorResponse struct { var errorResponse struct {
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
bts := scanner.Bytes() bts := scanner.Bytes()
@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
return fmt.Errorf("unmarshal: %w", err) return fmt.Errorf("unmarshal: %w", err)
} }
if len(errorResponse.Error) > 0 { if response.StatusCode >= 400 {
return fmt.Errorf("stream: %s", errorResponse.Error) return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
Message: errorResponse.Error,
}
} }
if err := callback(bts); err != nil { if err := fn(bts); err != nil {
return err return err
} }
} }

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http"
"os" "os"
"path" "path"
"strings" "strings"
@ -34,7 +35,14 @@ func RunRun(cmd *cobra.Command, args []string) error {
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := pull(args[0]); err != nil { if err := pull(args[0]); err != nil {
return err var apiStatusError api.StatusError
if !errors.As(err, &apiStatusError) {
return err
}
if apiStatusError.StatusCode != http.StatusBadGateway {
return err
}
} }
case err != nil: case err != nil:
return err return err
@ -50,11 +58,12 @@ func pull(model string) error {
context.Background(), context.Background(),
&api.PullRequest{Model: model}, &api.PullRequest{Model: model},
func(progress api.PullProgress) error { func(progress api.PullProgress) error {
if bar == nil && progress.Percent == 100 {
// already downloaded
return nil
}
if bar == nil { if bar == nil {
if progress.Percent == 100 {
// already downloaded
return nil
}
bar = progressbar.DefaultBytes(progress.Total) bar = progressbar.DefaultBytes(progress.Total)
} }

1
go.mod
View file

@ -39,7 +39,6 @@ require (
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.10.0 // indirect golang.org/x/crypto v0.10.0 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect
golang.org/x/sync v0.3.0
golang.org/x/sys v0.10.0 // indirect golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0 golang.org/x/term v0.10.0
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect

2
go.sum
View file

@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View file

@ -2,14 +2,13 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os" "os"
"path" "path"
"strconv" "strconv"
"github.com/jmorganca/ollama/api"
) )
const directoryURL = "https://ollama.ai/api/models" const directoryURL = "https://ollama.ai/api/models"
@ -36,12 +35,12 @@ func (m *Model) FullName() string {
return path.Join(home, ".ollama", "models", m.Name+".bin") return path.Join(home, ".ollama", "models", m.Name+".bin")
} }
func pull(model string, progressCh chan<- api.PullProgress) error { func (m *Model) TempFile() string {
remote, err := getRemote(model) fullName := m.FullName()
if err != nil { return path.Join(
return fmt.Errorf("failed to pull model: %w", err) path.Dir(fullName),
} fmt.Sprintf(".%s.part", path.Base(fullName)),
return saveModel(remote, progressCh) )
} }
func getRemote(model string) (*Model, error) { func getRemote(model string) (*Model, error) {
@ -68,7 +67,7 @@ func getRemote(model string) (*Model, error) {
return nil, fmt.Errorf("model not found in directory: %s", model) return nil, fmt.Errorf("model not found in directory: %s", model)
} }
func saveModel(model *Model, progressCh chan<- api.PullProgress) error { func saveModel(model *Model, fn func(total, completed int64)) error {
// this models cache directory is created by the server on startup // this models cache directory is created by the server on startup
client := &http.Client{} client := &http.Client{}
@ -76,41 +75,45 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to download model: %w", err) return fmt.Errorf("failed to download model: %w", err)
} }
// check for resume
alreadyDownloaded := int64(0) // check if completed file exists
fileInfo, err := os.Stat(model.FullName()) fi, err := os.Stat(model.FullName())
if err != nil { switch {
if !os.IsNotExist(err) { case errors.Is(err, os.ErrNotExist):
return fmt.Errorf("failed to check resume model file: %w", err) // noop, file doesn't exist so create it
} case err != nil:
// file doesn't exist, create it now return fmt.Errorf("stat: %w", err)
} else { default:
alreadyDownloaded = fileInfo.Size() fn(fi.Size(), fi.Size())
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded)) return nil
} }
var size int64
// completed file doesn't exist, check partial file
fi, err = os.Stat(model.TempFile())
switch {
case errors.Is(err, os.ErrNotExist):
// noop, file doesn't exist so create it
case err != nil:
return fmt.Errorf("stat: %w", err)
default:
size = fi.Size()
}
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to download model: %w", err) return fmt.Errorf("failed to download model: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable { if resp.StatusCode >= 400 {
// already downloaded
progressCh <- api.PullProgress{
Total: alreadyDownloaded,
Completed: alreadyDownloaded,
Percent: 100,
}
return nil
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
return fmt.Errorf("failed to download model: %s", resp.Status) return fmt.Errorf("failed to download model: %s", resp.Status)
} }
out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -118,37 +121,23 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
buf := make([]byte, 1024) totalBytes := size
totalBytes := alreadyDownloaded totalSize += size
totalSize += alreadyDownloaded
for { for {
n, err := resp.Body.Read(buf) n, err := io.CopyN(out, resp.Body, 8192)
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
return err return err
} }
if n == 0 { if n == 0 {
break break
} }
if _, err := out.Write(buf[:n]); err != nil {
return err
}
totalBytes += int64(n) totalBytes += n
fn(totalSize, totalBytes)
// send progress updates
progressCh <- api.PullProgress{
Total: totalSize,
Completed: totalBytes,
Percent: float64(totalBytes) / float64(totalSize) * 100,
}
} }
progressCh <- api.PullProgress{ fn(totalSize, totalSize)
Total: totalSize, return os.Rename(model.TempFile(), model.FullName())
Completed: totalSize,
Percent: 100,
}
return nil
} }

View file

@ -16,7 +16,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/lithammer/fuzzysearch/fuzzy" "github.com/lithammer/fuzzysearch/fuzzy"
"golang.org/x/sync/errgroup"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
@ -56,12 +55,8 @@ func generate(c *gin.Context) {
req.Model = path.Join(cacheDir(), "models", req.Model+".bin") req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
} }
llm, err := llama.New(req.Model, req.Options) ch := make(chan any)
if err != nil { go stream(c, ch)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer llm.Close()
templateNames := make([]string, 0, len(templates.Templates())) templateNames := make([]string, 0, len(templates.Templates()))
for _, template := range templates.Templates() { for _, template := range templates.Templates() {
@ -79,39 +74,49 @@ func generate(c *gin.Context) {
req.Prompt = sb.String() req.Prompt = sb.String()
} }
ch := make(chan string) llm, err := llama.New(req.Model, req.Options)
g, _ := errgroup.WithContext(c.Request.Context()) if err != nil {
g.Go(func() error { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
defer close(ch) return
return llm.Predict(req.Prompt, func(s string) { }
ch <- s defer llm.Close()
})
})
g.Go(func() error { fn := func(s string) {
c.Stream(func(w io.Writer) bool { ch <- api.GenerateResponse{Response: s}
s, ok := <-ch }
if !ok {
return false
}
bts, err := json.Marshal(api.GenerateResponse{Response: s}) if err := llm.Predict(req.Prompt, fn); err != nil {
if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return false return
} }
bts = append(bts, '\n') }
if _, err := w.Write(bts); err != nil {
return false
}
return true func pull(c *gin.Context) {
}) var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
return nil remote, err := getRemote(req.Model)
}) if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) { ch := make(chan any)
go stream(c, ch)
fn := func(total, completed int64) {
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(total) / float64(completed) * 100,
}
}
if err := saveModel(remote, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@ -124,47 +129,7 @@ func Serve(ln net.Listener) error {
c.String(http.StatusOK, "Ollama is running") c.String(http.StatusOK, "Ollama is running")
}) })
r.POST("api/pull", func(c *gin.Context) { r.POST("api/pull", pull)
var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
progressCh := make(chan api.PullProgress)
go func() {
defer close(progressCh)
if err := pull(req.Model, progressCh); err != nil {
var opError *net.OpError
if errors.As(err, &opError) {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
}()
c.Stream(func(w io.Writer) bool {
progress, ok := <-progressCh
if !ok {
return false
}
bts, err := json.Marshal(progress)
if err != nil {
return false
}
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
return false
}
return true
})
})
r.POST("/api/generate", generate) r.POST("/api/generate", generate)
log.Printf("Listening on %s", ln.Addr()) log.Printf("Listening on %s", ln.Addr())
@ -186,3 +151,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
return return
} }
func stream(c *gin.Context, ch chan any) {
c.Stream(func(w io.Writer) bool {
val, ok := <-ch
if !ok {
return false
}
bts, err := json.Marshal(val)
if err != nil {
return false
}
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
return false
}
return true
})
}