pr feedback

- move error check to api client pull
- simplify error check in generate
- return nil on any pull error
This commit is contained in:
Bruce MacDonald 2023-07-07 17:12:02 -04:00
parent 61dd87bd90
commit f533f85d44
3 changed files with 13 additions and 19 deletions

View file

@ -106,6 +106,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
return err return err
} }
if resp.Error.Message != "" {
// couldn't pull the model from the directory, proceed anyway
return nil
}
return fn(resp) return fn(resp)
}), }),
) )

View file

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http"
"os" "os"
"path" "path"
"strings" "strings"
@ -51,10 +50,6 @@ func pull(model string) error {
context.Background(), context.Background(),
&api.PullRequest{Model: model}, &api.PullRequest{Model: model},
func(progress api.PullProgress) error { func(progress api.PullProgress) error {
if progress.Error.Code == http.StatusBadGateway {
// couldn't pull the model from the directory, proceed in offline mode
return nil
}
if bar == nil && progress.Percent == 100 { if bar == nil && progress.Percent == 100 {
// already downloaded // already downloaded
return nil return nil

View file

@ -45,22 +45,16 @@ func generate(c *gin.Context) {
return return
} }
remoteModel, err := getRemote(req.Model) if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
if err != nil {
// couldn't check the directory, proceed in offline mode
_, err := os.Stat(req.Model)
if err != nil {
if !os.IsNotExist(err) {
c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return
}
// couldn't find the model file, try setting the model to the cache directory
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
}
}
if remoteModel != nil {
req.Model = remoteModel.FullName() req.Model = remoteModel.FullName()
} }
if _, err := os.Stat(req.Model); err != nil {
if !errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
}
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
}
modelOpts := getModelOpts(req) modelOpts := getModelOpts(req)
modelOpts.NGPULayers = 1 // hard-code this for now modelOpts.NGPULayers = 1 // hard-code this for now