From 61dd87bd907162804263472c92fbcd2a7335421b Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 7 Jul 2023 15:27:43 -0400 Subject: [PATCH 1/2] if directory cannot be resolved, do not fail --- api/types.go | 1 + cmd/cmd.go | 5 +++++ server/routes.go | 39 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 2 deletions(-) diff --git a/api/types.go b/api/types.go index 84dc3731..79bc2c24 100644 --- a/api/types.go +++ b/api/types.go @@ -26,6 +26,7 @@ type PullProgress struct { Total int64 `json:"total"` Completed int64 `json:"completed"` Percent float64 `json:"percent"` + Error Error `json:"error"` } type GenerateRequest struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index e6c12367..1b71d3ad 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "net/http" "os" "path" "strings" @@ -50,6 +51,10 @@ func pull(model string) error { context.Background(), &api.PullRequest{Model: model}, func(progress api.PullProgress) error { + if progress.Error.Code == http.StatusBadGateway { + // couldn't pull the model from the directory, proceed in offline mode + return nil + } if bar == nil && progress.Percent == 100 { // already downloaded return nil diff --git a/server/routes.go b/server/routes.go index 21a23785..d1280286 100644 --- a/server/routes.go +++ b/server/routes.go @@ -3,12 +3,14 @@ package server import ( "embed" "encoding/json" + "errors" "fmt" "io" "log" "math" "net" "net/http" + "os" "path" "runtime" "strings" @@ -25,6 +27,15 @@ import ( var templatesFS embed.FS var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) +func cacheDir() string { + home, err := os.UserHomeDir() + if err != nil { + panic(err) + } + + return path.Join(home, ".ollama") +} + func generate(c *gin.Context) { var req api.GenerateRequest req.ModelOptions = api.DefaultModelOptions @@ -34,12 +45,25 @@ func generate(c *gin.Context) { return } - if remoteModel, _ := getRemote(req.Model); remoteModel != nil { + remoteModel, err := getRemote(req.Model) + if err != nil { + // couldn't check the directory, proceed in offline mode + _, err := os.Stat(req.Model) + if err != nil { + if !os.IsNotExist(err) { + c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return + } + // couldn't find the model file, try setting the model to the cache directory + req.Model = path.Join(cacheDir(), "models", req.Model+".bin") + } + } + if remoteModel != nil { req.Model = remoteModel.FullName() } modelOpts := getModelOpts(req) - modelOpts.NGPULayers = 1 // hard-code this for now + modelOpts.NGPULayers = 1 // hard-code this for now model, err := llama.New(req.Model, modelOpts) if err != nil { @@ -118,6 +142,17 @@ func Serve(ln net.Listener) error { go func() { defer close(progressCh) if err := pull(req.Model, progressCh); err != nil { + var opError *net.OpError + if errors.As(err, &opError) { + result := api.PullProgress{ + Error: api.Error{ + Code: http.StatusBadGateway, + Message: "failed to get models from directory", + }, + } + c.JSON(http.StatusBadGateway, result) + return + } c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } From f533f85d44e84124eddbcf4e4a1833bed0b04f96 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 7 Jul 2023 17:12:02 -0400 Subject: [PATCH 2/2] pr feedback - move error check to api client pull - simplify error check in generate - return nil on any pull error --- api/client.go | 5 +++++ cmd/cmd.go | 5 ----- server/routes.go | 22 ++++++++-------------- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/api/client.go b/api/client.go index f153f32e..65d36ecb 100644 --- a/api/client.go +++ b/api/client.go @@ -106,6 +106,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc return err } + if resp.Error.Message != "" { + // couldn't pull the model from the directory, proceed anyway + return nil + } + return fn(resp) }), ) diff --git a/cmd/cmd.go b/cmd/cmd.go index 1b71d3ad..e6c12367 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,7 +7,6 @@ import ( "fmt" "log" "net" - "net/http" "os" "path" "strings" @@ -51,10 +50,6 @@ func pull(model string) error { context.Background(), &api.PullRequest{Model: model}, func(progress api.PullProgress) error { - if progress.Error.Code == http.StatusBadGateway { - // couldn't pull the model from the directory, proceed in offline mode - return nil - } if bar == nil && progress.Percent == 100 { // already downloaded return nil diff --git a/server/routes.go b/server/routes.go index d1280286..684bfcf7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -45,22 +45,16 @@ func generate(c *gin.Context) { return } - remoteModel, err := getRemote(req.Model) - if err != nil { - // couldn't check the directory, proceed in offline mode - _, err := os.Stat(req.Model) - if err != nil { - if !os.IsNotExist(err) { - c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) - return - } - // couldn't find the model file, try setting the model to the cache directory - req.Model = path.Join(cacheDir(), "models", req.Model+".bin") - } - } - if remoteModel != nil { + if remoteModel, _ := getRemote(req.Model); remoteModel != nil { req.Model = remoteModel.FullName() } + if _, err := os.Stat(req.Model); err != nil { + if !errors.Is(err, os.ErrNotExist) { + c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) + return + } + req.Model = path.Join(cacheDir(), "models", req.Model+".bin") + } modelOpts := getModelOpts(req) modelOpts.NGPULayers = 1 // hard-code this for now