diff --git a/api/types.go b/api/types.go index 84dc3731..79bc2c24 100644 --- a/api/types.go +++ b/api/types.go @@ -26,6 +26,7 @@ type PullProgress struct { Total int64 `json:"total"` Completed int64 `json:"completed"` Percent float64 `json:"percent"` + Error Error `json:"error"` } type GenerateRequest struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index e6c12367..1b71d3ad 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net" + "net/http" "os" "path" "strings" @@ -50,6 +51,10 @@ func pull(model string) error { context.Background(), &api.PullRequest{Model: model}, func(progress api.PullProgress) error { + if progress.Error.Code == http.StatusBadGateway { + // couldn't pull the model from the directory, proceed in offline mode + return nil + } if bar == nil && progress.Percent == 100 { // already downloaded return nil diff --git a/server/routes.go b/server/routes.go index 21a23785..d1280286 100644 --- a/server/routes.go +++ b/server/routes.go @@ -3,12 +3,14 @@ package server import ( "embed" "encoding/json" + "errors" "fmt" "io" "log" "math" "net" "net/http" + "os" "path" "runtime" "strings" @@ -25,6 +27,15 @@ import ( var templatesFS embed.FS var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) +func cacheDir() string { + home, err := os.UserHomeDir() + if err != nil { + panic(err) + } + + return path.Join(home, ".ollama") +} + func generate(c *gin.Context) { var req api.GenerateRequest req.ModelOptions = api.DefaultModelOptions @@ -34,12 +45,25 @@ func generate(c *gin.Context) { return } - if remoteModel, _ := getRemote(req.Model); remoteModel != nil { + remoteModel, err := getRemote(req.Model) + if err != nil { + // couldn't check the directory, proceed in offline mode + _, err := os.Stat(req.Model) + if err != nil { + if !os.IsNotExist(err) { + c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return + } + // couldn't find the model file, try setting the model to the cache directory + req.Model = path.Join(cacheDir(), "models", req.Model+".bin") + } + } + if remoteModel != nil { req.Model = remoteModel.FullName() } modelOpts := getModelOpts(req) - modelOpts.NGPULayers = 1 // hard-code this for now + modelOpts.NGPULayers = 1 // hard-code this for now model, err := llama.New(req.Model, modelOpts) if err != nil { @@ -118,6 +142,17 @@ func Serve(ln net.Listener) error { go func() { defer close(progressCh) if err := pull(req.Model, progressCh); err != nil { + var opError *net.OpError + if errors.As(err, &opError) { + result := api.PullProgress{ + Error: api.Error{ + Code: http.StatusBadGateway, + Message: "failed to get models from directory", + }, + } + c.JSON(http.StatusBadGateway, result) + return + } c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return }