diff --git a/cmd/cmd.go b/cmd/cmd.go index fad06ffd..01eb66f9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -39,6 +39,7 @@ import ( "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -558,6 +559,8 @@ func PushHandler(cmd *cobra.Command, args []string) error { } request := api.PushRequest{Name: args[0], Insecure: insecure} + + n := model.ParseName(args[0]) if err := client.Push(cmd.Context(), &request, fn); err != nil { if spinner != nil { spinner.Stop() @@ -568,7 +571,16 @@ func PushHandler(cmd *cobra.Command, args []string) error { return err } + p.Stop() spinner.Stop() + + destination := n.String() + if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") { + destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest") + } + fmt.Printf("\nYou can find your model at:\n\n") + fmt.Printf("\t%s\n", destination) + return nil } diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index fd8289cf..2e6428cf 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "io" "net/http" "net/http/httptest" "os" @@ -369,3 +370,127 @@ func TestGetModelfileName(t *testing.T) { }) } } + +func TestPushHandler(t *testing.T) { + tests := []struct { + name string + modelName string + serverResponse map[string]func(w http.ResponseWriter, r *http.Request) + expectedError string + expectedOutput string + }{ + { + name: "successful push", + modelName: "test-model", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/push": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + + var req api.PushRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if req.Name != "test-model" { + t.Errorf("expected model name 'test-model', got %s", req.Name) + } + + // Simulate progress updates + responses := []api.ProgressResponse{ + {Status: "preparing manifest"}, + {Digest: "sha256:abc123456789", Total: 100, Completed: 50}, + {Digest: "sha256:abc123456789", Total: 100, Completed: 100}, + } + + for _, resp := range responses { + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.(http.Flusher).Flush() + } + }, + }, + expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", + }, + { + name: "unauthorized push", + modelName: "unauthorized-model", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/push": func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + err := json.NewEncoder(w).Encode(map[string]string{ + "error": "access denied", + }) + if err != nil { + t.Fatal(err) + } + }, + }, + expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if handler, ok := tt.serverResponse[r.URL.Path]; ok { + handler(w, r) + return + } + http.Error(w, "not found", http.StatusNotFound) + })) + defer mockServer.Close() + + t.Setenv("OLLAMA_HOST", mockServer.URL) + + cmd := &cobra.Command{} + cmd.Flags().Bool("insecure", false, "") + cmd.SetContext(context.TODO()) + + // Redirect stderr to capture progress output + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + // Capture stdout for the "Model pushed" message + oldStdout := os.Stdout + outR, outW, _ := os.Pipe() + os.Stdout = outW + + err := PushHandler(cmd, []string{tt.modelName}) + + // Restore stderr + w.Close() + os.Stderr = oldStderr + // drain the pipe + if _, err := io.ReadAll(r); err != nil { + t.Fatal(err) + } + + // Restore stdout and get output + outW.Close() + os.Stdout = oldStdout + stdout, _ := io.ReadAll(outR) + + if tt.expectedError == "" { + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if tt.expectedOutput != "" { + if got := string(stdout); got != tt.expectedOutput { + t.Errorf("expected output %q, got %q", tt.expectedOutput, got) + } + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error containing %q, got %v", tt.expectedError, err) + } + } + }) + } +}