diff --git a/cmd/cmd.go b/cmd/cmd.go index 4d3a48c7..a44ed5c6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -46,28 +46,58 @@ import ( "github.com/ollama/ollama/version" ) +var ( + errModelNotFound = errors.New("no Modelfile or safetensors files found") + errModelfileNotFound = errors.New("specified Modelfile wasn't found") +) + +func getModelfileName(cmd *cobra.Command) (string, error) { + fn, _ := cmd.Flags().GetString("file") + + filename := fn + if filename == "" { + filename = "Modelfile" + } + + absName, err := filepath.Abs(filename) + if err != nil { + return "", err + } + + _, err = os.Stat(absName) + if err != nil { + return fn, err + } + + return absName, nil +} + func CreateHandler(cmd *cobra.Command, args []string) error { - filename, _ := cmd.Flags().GetString("file") - filename, err := filepath.Abs(filename) - if err != nil { - return err - } - - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - p := progress.NewProgress(os.Stderr) defer p.Stop() - f, err := os.Open(filename) - if err != nil { - return err - } - defer f.Close() + var reader io.Reader - modelfile, err := parser.ParseFile(f) + filename, err := getModelfileName(cmd) + if os.IsNotExist(err) { + if filename == "" { + reader = strings.NewReader("FROM .\n") + } else { + return errModelfileNotFound + } + } else if err != nil { + return err + } else { + f, err := os.Open(filename) + if err != nil { + return err + } + + reader = f + defer f.Close() + } + + modelfile, err := parser.ParseFile(reader) if err != nil { return err } @@ -82,6 +112,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error { p.Add(status, spinner) defer p.Stop() + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + for i := range modelfile.Commands { switch modelfile.Commands[i].Name { case "model", "adapter": @@ -220,7 +255,7 @@ func tempZipFiles(path string) (string, error) { // covers consolidated.x.pth, consolidated.pth files = append(files, pt...) } else { - return "", errors.New("no safetensors or torch files found") + return "", errModelNotFound } // add configuration files, json files are detected as text/plain @@ -1315,7 +1350,7 @@ func NewCLI() *cobra.Command { RunE: CreateHandler, } - createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile") + createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)") showCmd := &cobra.Command{ diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 9d23f3e9..fd8289cf 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -270,3 +270,102 @@ func TestDeleteHandler(t *testing.T) { t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) } } + +func TestGetModelfileName(t *testing.T) { + tests := []struct { + name string + modelfileName string + fileExists bool + expectedName string + expectedErr error + }{ + { + name: "no modelfile specified, no modelfile exists", + modelfileName: "", + fileExists: false, + expectedName: "", + expectedErr: os.ErrNotExist, + }, + { + name: "no modelfile specified, modelfile exists", + modelfileName: "", + fileExists: true, + expectedName: "Modelfile", + expectedErr: nil, + }, + { + name: "modelfile specified, no modelfile exists", + modelfileName: "crazyfile", + fileExists: false, + expectedName: "crazyfile", + expectedErr: os.ErrNotExist, + }, + { + name: "modelfile specified, modelfile exists", + modelfileName: "anotherfile", + fileExists: true, + expectedName: "anotherfile", + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &cobra.Command{ + Use: "fakecmd", + } + cmd.Flags().String("file", "", "path to modelfile") + + var expectedFilename string + + if tt.fileExists { + tempDir, err := os.MkdirTemp("", "modelfiledir") + defer os.RemoveAll(tempDir) + if err != nil { + t.Fatalf("temp modelfile dir creation failed: %v", err) + } + var fn string + if tt.modelfileName != "" { + fn = tt.modelfileName + } else { + fn = "Modelfile" + } + + tempFile, err := os.CreateTemp(tempDir, fn) + if err != nil { + t.Fatalf("temp modelfile creation failed: %v", err) + } + + expectedFilename = tempFile.Name() + err = cmd.Flags().Set("file", expectedFilename) + if err != nil { + t.Fatalf("couldn't set file flag: %v", err) + } + } else { + if tt.modelfileName != "" { + expectedFilename = tt.modelfileName + err := cmd.Flags().Set("file", tt.modelfileName) + if err != nil { + t.Fatalf("couldn't set file flag: %v", err) + } + } + } + + actualFilename, actualErr := getModelfileName(cmd) + + if actualFilename != expectedFilename { + t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename) + } + + if tt.expectedErr != os.ErrNotExist { + if actualErr != tt.expectedErr { + t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr) + } + } else { + if !os.IsNotExist(actualErr) { + t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr) + } + } + }) + } +}