From f533f85d44e84124eddbcf4e4a1833bed0b04f96 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 7 Jul 2023 17:12:02 -0400 Subject: [PATCH] pr feedback - move error check to api client pull - simplify error check in generate - return nil on any pull error --- api/client.go | 5 +++++ cmd/cmd.go | 5 ----- server/routes.go | 22 ++++++++-------------- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/api/client.go b/api/client.go index f153f32e..65d36ecb 100644 --- a/api/client.go +++ b/api/client.go @@ -106,6 +106,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc return err } + if resp.Error.Message != "" { + // couldn't pull the model from the directory, proceed anyway + return nil + } + return fn(resp) }), ) diff --git a/cmd/cmd.go b/cmd/cmd.go index 1b71d3ad..e6c12367 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,7 +7,6 @@ import ( "fmt" "log" "net" - "net/http" "os" "path" "strings" @@ -51,10 +50,6 @@ func pull(model string) error { context.Background(), &api.PullRequest{Model: model}, func(progress api.PullProgress) error { - if progress.Error.Code == http.StatusBadGateway { - // couldn't pull the model from the directory, proceed in offline mode - return nil - } if bar == nil && progress.Percent == 100 { // already downloaded return nil diff --git a/server/routes.go b/server/routes.go index d1280286..684bfcf7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -45,22 +45,16 @@ func generate(c *gin.Context) { return } - remoteModel, err := getRemote(req.Model) - if err != nil { - // couldn't check the directory, proceed in offline mode - _, err := os.Stat(req.Model) - if err != nil { - if !os.IsNotExist(err) { - c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) - return - } - // couldn't find the model file, try setting the model to the cache directory - req.Model = path.Join(cacheDir(), "models", req.Model+".bin") - } - } - if remoteModel != nil { + if remoteModel, _ := getRemote(req.Model); remoteModel != nil { req.Model = remoteModel.FullName() } + if _, err := os.Stat(req.Model); err != nil { + if !errors.Is(err, os.ErrNotExist) { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + req.Model = path.Join(cacheDir(), "models", req.Model+".bin") + } modelOpts := getModelOpts(req) modelOpts.NGPULayers = 1 // hard-code this for now