prompt to display and add local ollama keys to account (#3717)
- return descriptive error messages when unauthorized to create blob or push a model - display the local public key associated with the request that was denied
This commit is contained in:
parent
5950c176ca
commit
0a7fdbe533
4 changed files with 155 additions and 7 deletions
36
auth/auth.go
36
auth/auth.go
|
@ -10,12 +10,44 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
|
func keyPath() (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPublicKey() (string, error) {
|
||||||
|
keyPath, err := keyPath()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||||
|
|
||||||
|
return strings.TrimSpace(string(publicKey)), nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewNonce(r io.Reader, length int) (string, error) {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
nonce := make([]byte, length)
|
nonce := make([]byte, length)
|
||||||
if _, err := io.ReadFull(r, nonce); err != nil {
|
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||||
|
@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
home, err := os.UserHomeDir()
|
keyPath, err := keyPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
|
58
cmd/cmd.go
58
cmd/cmd.go
|
@ -32,10 +32,13 @@ import (
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/server"
|
"github.com/ollama/ollama/server"
|
||||||
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -357,6 +360,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func errFromUnknownKey(unknownKeyErr error) error {
|
||||||
|
// find SSH public key in the error message
|
||||||
|
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||||
|
re := regexp.MustCompile(sshKeyPattern)
|
||||||
|
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||||
|
|
||||||
|
if len(matches) > 0 {
|
||||||
|
serverPubKey := matches[0]
|
||||||
|
|
||||||
|
localPubKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||||
|
// try the ollama service public key
|
||||||
|
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||||
|
if err != nil {
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||||
|
if serverPubKey != localPubKey {
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg strings.Builder
|
||||||
|
msg.WriteString(unknownKeyErr.Error())
|
||||||
|
msg.WriteString("\n\nYour ollama key is:\n")
|
||||||
|
msg.WriteString(localPubKey)
|
||||||
|
msg.WriteString("\nAdd your key at:\n")
|
||||||
|
msg.WriteString("https://ollama.com/settings/keys")
|
||||||
|
|
||||||
|
return errors.New(msg.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return unknownKeyErr
|
||||||
|
}
|
||||||
|
|
||||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -404,6 +448,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||||
|
if spinner != nil {
|
||||||
|
spinner.Stop()
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), "access denied") {
|
||||||
|
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||||
|
}
|
||||||
|
host := model.ParseName(args[0]).Host
|
||||||
|
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
|
||||||
|
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||||
|
// the user has not added their ollama key to ollama.com
|
||||||
|
// re-throw an error with a more user-friendly message
|
||||||
|
return errFromUnknownKey(err)
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -25,10 +26,12 @@ import (
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
@ -980,9 +983,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||||
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
||||||
if errors.Is(err, errUnauthorized) {
|
|
||||||
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1145,9 +1145,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||||
}
|
}
|
||||||
|
|
||||||
var errUnauthorized = errors.New("unauthorized")
|
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
|
||||||
|
|
||||||
|
// getTokenSubject returns the subject of a JWT token, it does not validate the token
|
||||||
|
func getTokenSubject(token string) string {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
slog.Error("jwt token does not contain 3 parts")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := parts[1]
|
||||||
|
payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var payloadMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
|
||||||
|
slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
sub, ok := payloadMap["sub"]
|
||||||
|
if !ok {
|
||||||
|
slog.Error("jwt does not contain 'sub' field")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s", sub)
|
||||||
|
}
|
||||||
|
|
||||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||||
|
anonymous := true // access will default to anonymous if no user is found associated with the public key
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1166,6 +1197,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
anonymous = getTokenSubject(token) == "anonymous"
|
||||||
regOpts.Token = token
|
regOpts.Token = token
|
||||||
if body != nil {
|
if body != nil {
|
||||||
_, err = body.Seek(0, io.SeekStart)
|
_, err = body.Seek(0, io.SeekStart)
|
||||||
|
@ -1186,6 +1218,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if anonymous {
|
||||||
|
// no user is associated with the public key, and the request requires non-anonymous access
|
||||||
|
pubKey, nestedErr := auth.GetPublicKey()
|
||||||
|
if nestedErr != nil {
|
||||||
|
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||||
|
return nil, errUnauthorized
|
||||||
|
}
|
||||||
|
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
||||||
|
}
|
||||||
|
// user is associated with the public key, but is not authorized to make the request
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
18
types/errtypes/errtypes.go
Normal file
18
types/errtypes/errtypes.go
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
// Package errtypes contains custom error types
|
||||||
|
package errtypes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const UnknownOllamaKeyErrMsg = "unknown ollama key"
|
||||||
|
|
||||||
|
// TODO: This should have a structured response from the API
|
||||||
|
type UnknownOllamaKey struct {
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UnknownOllamaKey) Error() string {
|
||||||
|
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||||
|
}
|
Loading…
Reference in a new issue