fix stream errors

This commit is contained in:
Michael Yang 2023-07-20 12:12:08 -07:00
parent 00aaa05901
commit 1f27d7f1b8
3 changed files with 13 additions and 15 deletions

View file

@ -131,6 +131,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return fmt.Errorf("stream: %s", errorResponse.Error)
}
if response.StatusCode >= 400 {
return StatusError{
StatusCode: response.StatusCode,

View file

@ -192,7 +192,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
fn("parsing modelfile")
commands, err := parser.Parse(mf)
if err != nil {
fn(fmt.Sprintf("error: %v", err))
return err
}
@ -227,14 +226,12 @@ func CreateModel(name string, path string, fn func(status string)) error {
fn("creating model layer")
file, err := os.Open(fp)
if err != nil {
fn(fmt.Sprintf("couldn't find model '%s'", c.Args))
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
l, err := CreateLayer(file)
if err != nil {
fn(fmt.Sprintf("couldn't create model layer: %v", err))
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.model"
@ -244,7 +241,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
for _, l := range mf.Layers {
newLayer, err := GetLayerWithBufferFromLayer(l)
if err != nil {
fn(fmt.Sprintf("couldn't read layer: %v", err))
return err
}
layers = append(layers, newLayer)
@ -304,7 +300,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
err = SaveLayers(layers, fn, false)
if err != nil {
fn(fmt.Sprintf("error saving layers: %v", err))
return err
}
@ -312,7 +307,6 @@ func CreateModel(name string, path string, fn func(status string)) error {
fn("writing manifest")
err = CreateManifest(name, cfg, manifestLayers)
if err != nil {
fn(fmt.Sprintf("error creating manifest: %v", err))
return err
}
@ -610,7 +604,6 @@ func PullModel(name, username, password string, fn func(api.ProgressResponse)) e
for _, layer := range layers {
if err := downloadBlob(mp, layer.Digest, username, password, fn); err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("error downloading: %v", err), Digest: layer.Digest})
return err
}
}

View file

@ -60,7 +60,7 @@ func generate(c *gin.Context) {
ch := make(chan any)
go func() {
defer close(ch)
llm.Predict(req.Context, prompt, func(r api.GenerateResponse) {
fn := func(r api.GenerateResponse) {
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
if r.Done {
@ -68,7 +68,11 @@ func generate(c *gin.Context) {
}
ch <- r
})
}
if err := llm.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
streamResponse(c, ch)
@ -89,8 +93,7 @@ func pull(c *gin.Context) {
}
if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
ch <- gin.H{"error": err.Error()}
}
}()
@ -112,8 +115,7 @@ func push(c *gin.Context) {
}
if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
ch <- gin.H{"error": err.Error()}
}
}()
@ -137,8 +139,7 @@ func create(c *gin.Context) {
}
if err := CreateModel(req.Name, req.Path, fn); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
ch <- gin.H{"error": err.Error()}
}
}()