Add forward authentication option
This commit is contained in:
parent
f16219f90a
commit
52b69fbcb8
11 changed files with 252 additions and 105 deletions
60
auth/forward.go
Normal file
60
auth/forward.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/types"
|
||||
)
|
||||
|
||||
// Forward the authentication to a external server
|
||||
func Forward(forward *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
httpClient := http.Client{}
|
||||
|
||||
if forward.TLS != nil {
|
||||
tlsConfig, err := forward.TLS.CreateTLSConfig()
|
||||
if err != nil {
|
||||
log.Debugf("Impossible to configure TLS to call %s. Cause %s", forward.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
httpClient.Transport = &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
forwardReq, err := http.NewRequest(http.MethodGet, forward.Address, nil)
|
||||
if err != nil {
|
||||
log.Debugf("Error calling %s. Cause %s", forward.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
forwardReq.Header = r.Header
|
||||
|
||||
forwardResponse, forwardErr := httpClient.Do(forwardReq)
|
||||
if forwardErr != nil {
|
||||
log.Debugf("Error calling %s. Cause: %s", forward.Address, forwardErr)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
body, readError := ioutil.ReadAll(forwardResponse.Body)
|
||||
if readError != nil {
|
||||
log.Debugf("Error reading body %s. Cause: %s", forward.Address, readError)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer forwardResponse.Body.Close()
|
||||
|
||||
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Debugf("Remote error %s. StatusCode: %d", forward.Address, forwardResponse.StatusCode)
|
||||
w.WriteHeader(forwardResponse.StatusCode)
|
||||
w.Write(body)
|
||||
return
|
||||
}
|
||||
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
next(w, r)
|
||||
}
|
|
@ -11,7 +11,7 @@ import (
|
|||
"github.com/containous/staert"
|
||||
"github.com/containous/traefik/cluster"
|
||||
"github.com/containous/traefik/integration/try"
|
||||
"github.com/containous/traefik/provider"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/docker/libkv"
|
||||
"github.com/docker/libkv/store"
|
||||
"github.com/docker/libkv/store/consul"
|
||||
|
@ -52,7 +52,7 @@ func (s *ConsulSuite) setupConsulTLS(c *check.C) {
|
|||
s.composeProject.Start(c)
|
||||
|
||||
consul.Register()
|
||||
clientTLS := &provider.ClientTLS{
|
||||
clientTLS := &types.ClientTLS{
|
||||
CA: "resources/tls/ca.cert",
|
||||
Cert: "resources/tls/consul.cert",
|
||||
Key: "resources/tls/consul.key",
|
||||
|
|
|
@ -6,7 +6,8 @@ import (
|
|||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/abbot/go-http-auth"
|
||||
goauth "github.com/abbot/go-http-auth"
|
||||
"github.com/containous/traefik/auth"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/urfave/negroni"
|
||||
|
@ -30,7 +31,7 @@ func NewAuthenticator(authConfig *types.Auth) (*Authenticator, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
basicAuth := auth.NewBasicAuthenticator("traefik", authenticator.secretBasic)
|
||||
basicAuth := goauth.NewBasicAuthenticator("traefik", authenticator.secretBasic)
|
||||
authenticator.handler = negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if username := basicAuth.CheckAuth(r); username == "" {
|
||||
log.Debug("Basic auth failed...")
|
||||
|
@ -48,7 +49,7 @@ func NewAuthenticator(authConfig *types.Auth) (*Authenticator, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
digestAuth := auth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
|
||||
digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
|
||||
authenticator.handler = negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if username, _ := digestAuth.CheckAuth(r); username == "" {
|
||||
log.Debug("Digest auth failed...")
|
||||
|
@ -61,6 +62,10 @@ func NewAuthenticator(authConfig *types.Auth) (*Authenticator, error) {
|
|||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
} else if authConfig.Forward != nil {
|
||||
authenticator.handler = negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
auth.Forward(authConfig.Forward, w, r, next)
|
||||
})
|
||||
}
|
||||
return &authenticator, nil
|
||||
}
|
||||
|
|
|
@ -186,3 +186,67 @@ func TestBasicAuthUserHeader(t *testing.T) {
|
|||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestForwardAuthFail(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestForwardAuthSuccess(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Success")
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ type Provider struct {
|
|||
provider.BaseProvider `mapstructure:",squash"`
|
||||
Endpoint string `description:"Docker server endpoint. Can be a tcp or a unix socket endpoint"`
|
||||
Domain string `description:"Default domain used"`
|
||||
TLS *provider.ClientTLS `description:"Enable Docker TLS support"`
|
||||
TLS *types.ClientTLS `description:"Enable Docker TLS support"`
|
||||
ExposedByDefault bool `description:"Expose containers by default"`
|
||||
UseBindPortIP bool `description:"Use the ip address from the bound port, rather than from the inner network"`
|
||||
SwarmMode bool `description:"Use Docker on Swarm Mode"`
|
||||
|
|
|
@ -23,7 +23,7 @@ type Provider struct {
|
|||
provider.BaseProvider `mapstructure:",squash"`
|
||||
Endpoint string `description:"Comma separated server endpoints"`
|
||||
Prefix string `description:"Prefix used for KV store"`
|
||||
TLS *provider.ClientTLS `description:"Enable TLS support"`
|
||||
TLS *types.ClientTLS `description:"Enable TLS support"`
|
||||
Username string `description:"KV Username"`
|
||||
Password string `description:"KV Password"`
|
||||
storeType store.Backend
|
||||
|
|
|
@ -59,7 +59,7 @@ type Provider struct {
|
|||
GroupsAsSubDomains bool `description:"Convert Marathon groups to subdomains"`
|
||||
DCOSToken string `description:"DCOSToken for DCOS environment, This will override the Authorization header"`
|
||||
MarathonLBCompatibility bool `description:"Add compatibility with marathon-lb labels"`
|
||||
TLS *provider.ClientTLS `description:"Enable Docker TLS support"`
|
||||
TLS *types.ClientTLS `description:"Enable Docker TLS support"`
|
||||
DialerTimeout flaeg.Duration `description:"Set a non-default connection timeout for Marathon"`
|
||||
KeepAlive flaeg.Duration `description:"Set a non-default TCP Keep Alive time in seconds"`
|
||||
ForceTaskHostname bool `description:"Force to use the task's hostname."`
|
||||
|
|
|
@ -2,11 +2,7 @@ package provider
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
"unicode"
|
||||
|
@ -123,71 +119,3 @@ func ReverseStringSlice(slice *[]string) {
|
|||
(*slice)[i], (*slice)[j] = (*slice)[j], (*slice)[i]
|
||||
}
|
||||
}
|
||||
|
||||
// ClientTLS holds TLS specific configurations as client
|
||||
// CA, Cert and Key can be either path or file contents
|
||||
type ClientTLS struct {
|
||||
CA string `description:"TLS CA"`
|
||||
Cert string `description:"TLS cert"`
|
||||
Key string `description:"TLS key"`
|
||||
InsecureSkipVerify bool `description:"TLS insecure skip verify"`
|
||||
}
|
||||
|
||||
// CreateTLSConfig creates a TLS config from ClientTLS structures
|
||||
func (clientTLS *ClientTLS) CreateTLSConfig() (*tls.Config, error) {
|
||||
var err error
|
||||
if clientTLS == nil {
|
||||
log.Warnf("clientTLS is nil")
|
||||
return nil, nil
|
||||
}
|
||||
caPool := x509.NewCertPool()
|
||||
if clientTLS.CA != "" {
|
||||
var ca []byte
|
||||
if _, errCA := os.Stat(clientTLS.CA); errCA == nil {
|
||||
ca, err = ioutil.ReadFile(clientTLS.CA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to read CA. %s", err)
|
||||
}
|
||||
} else {
|
||||
ca = []byte(clientTLS.CA)
|
||||
}
|
||||
caPool.AppendCertsFromPEM(ca)
|
||||
}
|
||||
|
||||
cert := tls.Certificate{}
|
||||
_, errKeyIsFile := os.Stat(clientTLS.Key)
|
||||
|
||||
if !clientTLS.InsecureSkipVerify && (len(clientTLS.Cert) == 0 || len(clientTLS.Key) == 0) {
|
||||
return nil, fmt.Errorf("TLS Certificate or Key file must be set when TLS configuration is created")
|
||||
}
|
||||
|
||||
if len(clientTLS.Cert) > 0 && len(clientTLS.Key) > 0 {
|
||||
if _, errCertIsFile := os.Stat(clientTLS.Cert); errCertIsFile == nil {
|
||||
if errKeyIsFile == nil {
|
||||
cert, err = tls.LoadX509KeyPair(clientTLS.Cert, clientTLS.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to load TLS keypair: %v", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("tls cert is a file, but tls key is not")
|
||||
}
|
||||
} else {
|
||||
if errKeyIsFile != nil {
|
||||
cert, err = tls.X509KeyPair([]byte(clientTLS.Cert), []byte(clientTLS.Key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to load TLS keypair: %v", err)
|
||||
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("tls key is a file, but tls cert is not")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
RootCAs: caPool,
|
||||
InsecureSkipVerify: clientTLS.InsecureSkipVerify,
|
||||
}
|
||||
return TLSConfig, nil
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
type myProvider struct {
|
||||
BaseProvider
|
||||
TLS *ClientTLS
|
||||
TLS *types.ClientTLS
|
||||
}
|
||||
|
||||
func (p *myProvider) Foo() string {
|
||||
|
@ -202,7 +202,7 @@ func TestInsecureSkipVerifyClientTLS(t *testing.T) {
|
|||
BaseProvider{
|
||||
Filename: "",
|
||||
},
|
||||
&ClientTLS{
|
||||
&types.ClientTLS{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
@ -220,7 +220,7 @@ func TestInsecureSkipVerifyFalseClientTLS(t *testing.T) {
|
|||
BaseProvider{
|
||||
Filename: "",
|
||||
},
|
||||
&ClientTLS{
|
||||
&types.ClientTLS{
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -304,6 +304,15 @@
|
|||
# users = ["test:traefik:a2688e031edb4be6a3797f3882655c05 ", "test2:traefik:518845800f9e2bfb1f1f740ec24f074e"]
|
||||
# usersFile = "/path/to/.htdigest"
|
||||
#
|
||||
# To enable forward auth on an entrypoint
|
||||
# This configuration will first forward the request to http://authserver.com/auth. If the response code is 2XX,
|
||||
# access is granted and the original request is performed. Otherwise, the response from the auth server is returned.
|
||||
# [entryPoints]
|
||||
# [entryPoints.http]
|
||||
# address = ":80"
|
||||
# [entryPoints.http.auth.forward]
|
||||
# address = "http://authserver.com/auth"
|
||||
#
|
||||
# To specify an https entrypoint with a minimum TLS version, and specifying an array of cipher suites (from crypto/tls):
|
||||
# [entryPoints]
|
||||
# [entryPoints.https]
|
||||
|
|
|
@ -7,6 +7,12 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/docker/libkv/store"
|
||||
"github.com/ryanuber/go-glob"
|
||||
)
|
||||
|
@ -299,6 +305,7 @@ type Cluster struct {
|
|||
type Auth struct {
|
||||
Basic *Basic
|
||||
Digest *Digest
|
||||
Forward *Forward
|
||||
HeaderField string
|
||||
}
|
||||
|
||||
|
@ -317,6 +324,12 @@ type Digest struct {
|
|||
UsersFile string
|
||||
}
|
||||
|
||||
// Forward authentication
|
||||
type Forward struct {
|
||||
Address string `description:"Authentication server address"`
|
||||
TLS *ClientTLS `description:"Enable TLS support"`
|
||||
}
|
||||
|
||||
// CanonicalDomain returns a lower case domain with trim space
|
||||
func CanonicalDomain(domain string) string {
|
||||
return strings.ToLower(strings.TrimSpace(domain))
|
||||
|
@ -388,3 +401,71 @@ type AccessLog struct {
|
|||
FilePath string `json:"file,omitempty" description:"Access log file path. Stdout is used when omitted or empty"`
|
||||
Format string `json:"format,omitempty" description:"Access log format: json | common"`
|
||||
}
|
||||
|
||||
// ClientTLS holds TLS specific configurations as client
|
||||
// CA, Cert and Key can be either path or file contents
|
||||
type ClientTLS struct {
|
||||
CA string `description:"TLS CA"`
|
||||
Cert string `description:"TLS cert"`
|
||||
Key string `description:"TLS key"`
|
||||
InsecureSkipVerify bool `description:"TLS insecure skip verify"`
|
||||
}
|
||||
|
||||
// CreateTLSConfig creates a TLS config from ClientTLS structures
|
||||
func (clientTLS *ClientTLS) CreateTLSConfig() (*tls.Config, error) {
|
||||
var err error
|
||||
if clientTLS == nil {
|
||||
log.Warnf("clientTLS is nil")
|
||||
return nil, nil
|
||||
}
|
||||
caPool := x509.NewCertPool()
|
||||
if clientTLS.CA != "" {
|
||||
var ca []byte
|
||||
if _, errCA := os.Stat(clientTLS.CA); errCA == nil {
|
||||
ca, err = ioutil.ReadFile(clientTLS.CA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to read CA. %s", err)
|
||||
}
|
||||
} else {
|
||||
ca = []byte(clientTLS.CA)
|
||||
}
|
||||
caPool.AppendCertsFromPEM(ca)
|
||||
}
|
||||
|
||||
cert := tls.Certificate{}
|
||||
_, errKeyIsFile := os.Stat(clientTLS.Key)
|
||||
|
||||
if !clientTLS.InsecureSkipVerify && (len(clientTLS.Cert) == 0 || len(clientTLS.Key) == 0) {
|
||||
return nil, fmt.Errorf("TLS Certificate or Key file must be set when TLS configuration is created")
|
||||
}
|
||||
|
||||
if len(clientTLS.Cert) > 0 && len(clientTLS.Key) > 0 {
|
||||
if _, errCertIsFile := os.Stat(clientTLS.Cert); errCertIsFile == nil {
|
||||
if errKeyIsFile == nil {
|
||||
cert, err = tls.LoadX509KeyPair(clientTLS.Cert, clientTLS.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to load TLS keypair: %v", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("tls cert is a file, but tls key is not")
|
||||
}
|
||||
} else {
|
||||
if errKeyIsFile != nil {
|
||||
cert, err = tls.X509KeyPair([]byte(clientTLS.Cert), []byte(clientTLS.Key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to load TLS keypair: %v", err)
|
||||
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("tls key is a file, but tls cert is not")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TLSConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
RootCAs: caPool,
|
||||
InsecureSkipVerify: clientTLS.InsecureSkipVerify,
|
||||
}
|
||||
return TLSConfig, nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue