diff --git a/cmd/cmd.go b/cmd/cmd.go index a9b18b06..57af8707 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -43,27 +43,23 @@ func RunRun(cmd *cobra.Command, args []string) error { } func pull(model string) error { - // TODO: check if the local model is up to date with remote - _, err := os.Stat(cacheDir() + "/models/" + model + ".bin") - switch { - case errors.Is(err, os.ErrNotExist): - client := api.NewClient() - var bar *progressbar.ProgressBar - return client.Pull( - context.Background(), - &api.PullRequest{Model: model}, - func(progress api.PullProgress) error { - if bar == nil { - bar = progressbar.DefaultBytes(progress.Total) - } + client := api.NewClient() + var bar *progressbar.ProgressBar + return client.Pull( + context.Background(), + &api.PullRequest{Model: model}, + func(progress api.PullProgress) error { + if bar == nil && progress.Percent == 100 { + // already downloaded + return nil + } + if bar == nil { + bar = progressbar.DefaultBytes(progress.Total) + } - return bar.Set64(progress.Completed) - }, - ) - case err != nil: - return err - } - return nil + return bar.Set64(progress.Completed) + }, + ) } func RunGenerate(_ *cobra.Command, args []string) error {