From 6137b12799c6d40b75131ab36187474858cd8e8c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 21 Sep 2023 09:50:52 -0700 Subject: [PATCH] validate existence and pull model using api --- cmd/cmd.go | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index d5f8a157..e33f4ef3 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -11,7 +11,6 @@ import ( "io" "log" "net" - "net/http" "os" "os/exec" "path/filepath" @@ -108,35 +107,28 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } func RunHandler(cmd *cobra.Command, args []string) error { - insecure, err := cmd.Flags().GetBool("insecure") + client, err := api.FromEnv() if err != nil { return err } - mp := server.ParseModelPath(args[0]) - if mp.ProtocolScheme == "http" && !insecure { - return fmt.Errorf("insecure protocol http") - } - - fp, err := mp.GetManifestPath(false) + models, err := client.List(context.Background()) if err != nil { return err } - _, err = os.Stat(fp) - switch { - case errors.Is(err, os.ErrNotExist): - if err := pull(args[0], insecure); err != nil { - var apiStatusError api.StatusError - if !errors.As(err, &apiStatusError) { - return err - } + modelName, modelTag, ok := strings.Cut(args[0], ":") + if !ok { + modelTag = "latest" + } - if apiStatusError.StatusCode != http.StatusBadGateway { - return err - } + for _, model := range models.Models { + if model.Name == strings.Join([]string{modelName, modelTag}, ":") { + return RunGenerate(cmd, args) } - case err != nil: + } + + if err := PullHandler(cmd, args); err != nil { return err }