diff --git a/auth/auth.go b/auth/auth.go index ca64670d..026b2a2c 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -10,12 +10,44 @@ import ( "log/slog" "os" "path/filepath" + "strings" "golang.org/x/crypto/ssh" ) const defaultPrivateKey = "id_ed25519" +func keyPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + return filepath.Join(home, ".ollama", defaultPrivateKey), nil +} + +func GetPublicKey() (string, error) { + keyPath, err := keyPath() + if err != nil { + return "", err + } + + privateKeyFile, err := os.ReadFile(keyPath) + if err != nil { + slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) + return "", err + } + + privateKey, err := ssh.ParsePrivateKey(privateKeyFile) + if err != nil { + return "", err + } + + publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) + + return strings.TrimSpace(string(publicKey)), nil +} + func NewNonce(r io.Reader, length int) (string, error) { nonce := make([]byte, length) if _, err := io.ReadFull(r, nonce); err != nil { @@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) { } func Sign(ctx context.Context, bts []byte) (string, error) { - home, err := os.UserHomeDir() + keyPath, err := keyPath() if err != nil { return "", err } - keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) - privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) diff --git a/cmd/cmd.go b/cmd/cmd.go index a1eb8eba..2315ad1a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -32,10 +32,13 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" + "github.com/ollama/ollama/types/errtypes" + "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -357,6 +360,47 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generateInteractive(cmd, opts) } +func errFromUnknownKey(unknownKeyErr error) error { + // find SSH public key in the error message + sshKeyPattern := `ssh-\w+ [^\s"]+` + re := regexp.MustCompile(sshKeyPattern) + matches := re.FindStringSubmatch(unknownKeyErr.Error()) + + if len(matches) > 0 { + serverPubKey := matches[0] + + localPubKey, err := auth.GetPublicKey() + if err != nil { + return unknownKeyErr + } + + if runtime.GOOS == "linux" && serverPubKey != localPubKey { + // try the ollama service public key + svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") + if err != nil { + return unknownKeyErr + } + localPubKey = strings.TrimSpace(string(svcPubKey)) + } + + // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account + if serverPubKey != localPubKey { + return unknownKeyErr + } + + var msg strings.Builder + msg.WriteString(unknownKeyErr.Error()) + msg.WriteString("\n\nYour ollama key is:\n") + msg.WriteString(localPubKey) + msg.WriteString("\nAdd your key at:\n") + msg.WriteString("https://ollama.com/settings/keys") + + return errors.New(msg.String()) + } + + return unknownKeyErr +} + func PushHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -404,6 +448,20 @@ func PushHandler(cmd *cobra.Command, args []string) error { request := api.PushRequest{Name: args[0], Insecure: insecure} if err := client.Push(cmd.Context(), &request, fn); err != nil { + if spinner != nil { + spinner.Stop() + } + if strings.Contains(err.Error(), "access denied") { + return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") + } + host := model.ParseName(args[0]).Host + isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com") + if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost { + // the user has not added their ollama key to ollama.com + // re-throw an error with a more user-friendly message + return errFromUnknownKey(err) + } + return err } diff --git a/server/images.go b/server/images.go index 7b2199c7..4e4107f7 100644 --- a/server/images.go +++ b/server/images.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "crypto/sha256" + "encoding/base64" "encoding/hex" "encoding/json" "errors" @@ -25,10 +26,12 @@ import ( "golang.org/x/exp/slices" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/convert" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) @@ -980,9 +983,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu for _, layer := range layers { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { slog.Info(fmt.Sprintf("error uploading blob: %v", err)) - if errors.Is(err, errUnauthorized) { - return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository()) - } return err } } @@ -1145,9 +1145,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), n } -var errUnauthorized = errors.New("unauthorized") +var errUnauthorized = fmt.Errorf("unauthorized: access denied") + +// getTokenSubject returns the subject of a JWT token, it does not validate the token +func getTokenSubject(token string) string { + parts := strings.Split(token, ".") + if len(parts) != 3 { + slog.Error("jwt token does not contain 3 parts") + return "" + } + + payload := parts[1] + payloadBytes, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err)) + return "" + } + + var payloadMap map[string]interface{} + if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil { + slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err)) + return "" + } + + sub, ok := payloadMap["sub"] + if !ok { + slog.Error("jwt does not contain 'sub' field") + return "" + } + + return fmt.Sprintf("%s", sub) +} func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { + anonymous := true // access will default to anonymous if no user is found associated with the public key for i := 0; i < 2; i++ { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { @@ -1166,6 +1197,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR if err != nil { return nil, err } + anonymous = getTokenSubject(token) == "anonymous" regOpts.Token = token if body != nil { _, err = body.Seek(0, io.SeekStart) @@ -1186,6 +1218,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR } } + if anonymous { + // no user is associated with the public key, and the request requires non-anonymous access + pubKey, nestedErr := auth.GetPublicKey() + if nestedErr != nil { + slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr)) + return nil, errUnauthorized + } + return nil, &errtypes.UnknownOllamaKey{Key: pubKey} + } + // user is associated with the public key, but is not authorized to make the request return nil, errUnauthorized } diff --git a/types/errtypes/errtypes.go b/types/errtypes/errtypes.go new file mode 100644 index 00000000..e3a18d0b --- /dev/null +++ b/types/errtypes/errtypes.go @@ -0,0 +1,18 @@ +// Package errtypes contains custom error types +package errtypes + +import ( + "fmt" + "strings" +) + +const UnknownOllamaKeyErrMsg = "unknown ollama key" + +// TODO: This should have a structured response from the API +type UnknownOllamaKey struct { + Key string +} + +func (e *UnknownOllamaKey) Error() string { + return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key)) +}