From b4e3bca6fa347f743470ab3bf28d6c7ee88c20f3 Mon Sep 17 00:00:00 2001 From: SALLEYRON Julien Date: Tue, 24 Apr 2018 22:40:04 +0200 Subject: [PATCH] Remove acme provider dependency in server --- cmd/traefik/traefik.go | 46 ++++++++++++++++-------- configuration/configuration.go | 22 +++++++----- configuration/provider_aggregator.go | 24 ++++++------- provider/acme/provider.go | 53 +++++++--------------------- provider/acme/provider_test.go | 17 +++++---- server/server.go | 30 +++++++--------- tls/certificate_store.go | 33 +++++++++++++++++ 7 files changed, 124 insertions(+), 101 deletions(-) create mode 100644 tls/certificate_store.go diff --git a/cmd/traefik/traefik.go b/cmd/traefik/traefik.go index eb7637dab..b9042975c 100644 --- a/cmd/traefik/traefik.go +++ b/cmd/traefik/traefik.go @@ -24,7 +24,6 @@ import ( "github.com/containous/traefik/configuration/router" "github.com/containous/traefik/job" "github.com/containous/traefik/log" - "github.com/containous/traefik/provider/acme" "github.com/containous/traefik/provider/ecs" "github.com/containous/traefik/provider/kubernetes" "github.com/containous/traefik/safe" @@ -174,28 +173,47 @@ func runCmd(globalConfiguration *configuration.GlobalConfiguration, configFile s stats(globalConfiguration) log.Debugf("Global configuration loaded %s", string(jsonConf)) - if acme.IsEnabled() { - store := acme.NewLocalStore(acme.Get().Storage) - acme.Get().Store = &store + + providerAggregator := configuration.NewProviderAggregator(globalConfiguration) + + acmeprovider := globalConfiguration.InitACMEProvider() + if acmeprovider != nil { + providerAggregator.AddProvider(acmeprovider) } entryPoints := map[string]server.EntryPoint{} for entryPointName, config := range globalConfiguration.EntryPoints { - internalRouter := router.NewInternalRouterAggregator(*globalConfiguration, entryPointName) - if acme.IsEnabled() && acme.Get().HTTPChallenge != nil && acme.Get().HTTPChallenge.EntryPoint == entryPointName { - internalRouter.AddRouter(acme.Get()) + + entryPoint := server.EntryPoint{ + Configuration: config, } - entryPoints[entryPointName] = server.EntryPoint{ - InternalRouter: internalRouter, - Configuration: config, + internalRouter := router.NewInternalRouterAggregator(*globalConfiguration, entryPointName) + if acmeprovider != nil { + if acmeprovider.HTTPChallenge != nil && acmeprovider.HTTPChallenge.EntryPoint == entryPointName { + internalRouter.AddRouter(acmeprovider) + } + + if acmeprovider.EntryPoint == entryPointName && acmeprovider.OnDemand { + entryPoint.OnDemandListener = acmeprovider.ListenRequest + } + + entryPoint.CertificateStore = &traefiktls.CertificateStore{ + DynamicCerts: &safe.Safe{}, + StaticCerts: &safe.Safe{}, + } + acmeprovider.SetCertificateStore(*entryPoint.CertificateStore) + } + + entryPoint.InternalRouter = internalRouter + entryPoints[entryPointName] = entryPoint } - svr := server.NewServer(*globalConfiguration, configuration.NewProviderAggregator(globalConfiguration), entryPoints) - if acme.IsEnabled() && acme.Get().OnHostRule { - acme.Get().SetConfigListenerChan(make(chan types.Configuration)) - svr.AddListener(acme.Get().ListenConfiguration) + svr := server.NewServer(*globalConfiguration, providerAggregator, entryPoints) + if acmeprovider != nil && acmeprovider.OnHostRule { + acmeprovider.SetConfigListenerChan(make(chan types.Configuration)) + svr.AddListener(acmeprovider.ListenConfiguration) } ctx := cmd.ContextWithSignal(context.Background()) diff --git a/configuration/configuration.go b/configuration/configuration.go index d29bb4465..ed80485ac 100644 --- a/configuration/configuration.go +++ b/configuration/configuration.go @@ -370,11 +370,17 @@ func (gc *GlobalConfiguration) initACMEProvider() { if gc.ACME.OnDemand { log.Warn("ACME.OnDemand is deprecated") } + } +} +// InitACMEProvider create an acme provider from the ACME part of globalConfiguration +func (gc *GlobalConfiguration) InitACMEProvider() *acmeprovider.Provider { + if gc.ACME != nil { // TODO: Remove when Provider ACME will replace totally ACME // If provider file, use Provider ACME instead of ACME if gc.Cluster == nil { - acmeprovider.Get().Configuration = &acmeprovider.Configuration{ + provider := &acmeprovider.Provider{} + provider.Configuration = &acmeprovider.Configuration{ OnHostRule: gc.ACME.OnHostRule, OnDemand: gc.ACME.OnDemand, Email: gc.ACME.Email, @@ -386,9 +392,15 @@ func (gc *GlobalConfiguration) initACMEProvider() { CAServer: gc.ACME.CAServer, EntryPoint: gc.ACME.EntryPoint, } + + store := acmeprovider.NewLocalStore(provider.Storage) + provider.Store = &store + acme.ConvertToNewFormat(provider.Storage) gc.ACME = nil + return provider } } + return nil } // ValidateConfiguration validate that configuration is coherent @@ -401,14 +413,6 @@ func (gc *GlobalConfiguration) ValidateConfiguration() { log.Fatalf("Entrypoint %q has no TLS configuration for ACME configuration", gc.ACME.EntryPoint) } } - } else if acmeprovider.IsEnabled() { - if _, ok := gc.EntryPoints[acmeprovider.Get().EntryPoint]; !ok { - log.Fatalf("Unknown entrypoint %q for provider ACME configuration", acmeprovider.Get().EntryPoint) - } else { - if gc.EntryPoints[acmeprovider.Get().EntryPoint].TLS == nil { - log.Fatalf("Entrypoint %q has no TLS configuration for provider ACME configuration", acmeprovider.Get().EntryPoint) - } - } } } diff --git a/configuration/provider_aggregator.go b/configuration/provider_aggregator.go index 48377fb08..20661176e 100644 --- a/configuration/provider_aggregator.go +++ b/configuration/provider_aggregator.go @@ -4,21 +4,20 @@ import ( "encoding/json" "reflect" - "github.com/containous/traefik/acme" "github.com/containous/traefik/log" "github.com/containous/traefik/provider" - acmeprovider "github.com/containous/traefik/provider/acme" "github.com/containous/traefik/safe" "github.com/containous/traefik/types" ) -type providerAggregator struct { +// ProviderAggregator aggregate providers +type ProviderAggregator struct { providers []provider.Provider } // NewProviderAggregator return an aggregate of all the providers configured in GlobalConfiguration -func NewProviderAggregator(gc *GlobalConfiguration) provider.Provider { - provider := providerAggregator{} +func NewProviderAggregator(gc *GlobalConfiguration) ProviderAggregator { + provider := ProviderAggregator{} if gc.Docker != nil { provider.providers = append(provider.providers, gc.Docker) } @@ -67,17 +66,16 @@ func NewProviderAggregator(gc *GlobalConfiguration) provider.Provider { if gc.ServiceFabric != nil { provider.providers = append(provider.providers, gc.ServiceFabric) } - if acmeprovider.IsEnabled() { - provider.providers = append(provider.providers, acmeprovider.Get()) - acme.ConvertToNewFormat(acmeprovider.Get().Storage) - } - if len(provider.providers) == 1 { - return provider.providers[0] - } return provider } -func (p providerAggregator) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool, constraints types.Constraints) error { +// AddProvider add a provider in the providers map +func (p *ProviderAggregator) AddProvider(provider provider.Provider) { + p.providers = append(p.providers, provider) +} + +// Provide call the provide method of every providers +func (p ProviderAggregator) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool, constraints types.Constraints) error { for _, p := range p.providers { providerType := reflect.TypeOf(p) jsonConf, err := json.Marshal(p) diff --git a/provider/acme/provider.go b/provider/acme/provider.go index 19c9086e8..fe57ce173 100644 --- a/provider/acme/provider.go +++ b/provider/acme/provider.go @@ -20,7 +20,7 @@ import ( "github.com/containous/traefik/log" "github.com/containous/traefik/rules" "github.com/containous/traefik/safe" - traefikTLS "github.com/containous/traefik/tls" + traefiktls "github.com/containous/traefik/tls" "github.com/containous/traefik/types" "github.com/pkg/errors" acme "github.com/xenolf/lego/acmev2" @@ -30,7 +30,6 @@ import ( var ( // OSCPMustStaple enables OSCP stapling as from https://github.com/xenolf/lego/issues/270 OSCPMustStaple = false - provider = &Provider{} ) // Configuration holds ACME configuration provided by users @@ -56,8 +55,7 @@ type Provider struct { client *acme.Client certsChan chan *Certificate configurationChan chan<- types.ConfigMessage - dynamicCerts *safe.Safe - staticCerts map[string]*tls.Certificate + certificateStore traefiktls.CertificateStore clientMutex sync.Mutex configFromListenerChan chan types.Configuration pool *safe.Pool @@ -81,16 +79,6 @@ type HTTPChallenge struct { EntryPoint string `description:"HTTP challenge EntryPoint"` } -// Get returns the provider instance -func Get() *Provider { - return provider -} - -// IsEnabled returns true if the provider instance and its configuration are not nil, otherwise false -func IsEnabled() bool { - return provider != nil && provider.Configuration != nil -} - // SetConfigListenerChan initializes the configFromListenerChan func (p *Provider) SetConfigListenerChan(configFromListenerChan chan types.Configuration) { p.configFromListenerChan = configFromListenerChan @@ -196,14 +184,9 @@ func (p *Provider) watchNewDomains() { }) } -// SetDynamicCertificates allow to initialize dynamicCerts map -func (p *Provider) SetDynamicCertificates(safe *safe.Safe) { - p.dynamicCerts = safe -} - -// SetStaticCertificates allow to initialize staticCerts map -func (p *Provider) SetStaticCertificates(staticCerts map[string]*tls.Certificate) { - p.staticCerts = staticCerts +// SetCertificateStore allow to initialize certificate store +func (p *Provider) SetCertificateStore(certificateStore traefiktls.CertificateStore) { + p.certificateStore = certificateStore } func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurationFile bool) (*acme.CertificateResource, error) { @@ -424,13 +407,13 @@ func (p *Provider) refreshCertificates() { Configuration: &types.Configuration{ Backends: map[string]*types.Backend{}, Frontends: map[string]*types.Frontend{}, - TLS: []*traefikTLS.Configuration{}, + TLS: []*traefiktls.Configuration{}, }, } for _, cert := range p.certificates { - certificate := &traefikTLS.Certificate{CertFile: traefikTLS.FileOrContent(cert.Certificate), KeyFile: traefikTLS.FileOrContent(cert.Key)} - config.Configuration.TLS = append(config.Configuration.TLS, &traefikTLS.Configuration{Certificate: certificate, EntryPoints: []string{p.EntryPoint}}) + certificate := &traefiktls.Certificate{CertFile: traefiktls.FileOrContent(cert.Certificate), KeyFile: traefiktls.FileOrContent(cert.Key)} + config.Configuration.TLS = append(config.Configuration.TLS, &traefiktls.Configuration{Certificate: certificate, EntryPoints: []string{p.EntryPoint}}) } p.configurationChan <- config } @@ -507,33 +490,23 @@ func (p *Provider) AddRoutes(router *mux.Router) { // from static and dynamic provided certificates func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string { log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck) - var allCerts []string + var allDomains []string - // Get static certificates - for domains := range p.staticCerts { - allCerts = append(allCerts, domains) - } - - // Get dynamic certificates - if p.dynamicCerts != nil && p.dynamicCerts.Get() != nil { - for domains := range p.dynamicCerts.Get().(map[string]*tls.Certificate) { - allCerts = append(allCerts, domains) - } - } + allDomains = p.certificateStore.GetAllDomains() // Get ACME certificates for _, certificate := range p.certificates { - allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ",")) + allDomains = append(allDomains, strings.Join(certificate.Domain.ToStrArray(), ",")) } // Get Configuration Domains if checkConfigurationDomains { for i := 0; i < len(p.Domains); i++ { - allCerts = append(allCerts, strings.Join(p.Domains[i].ToStrArray(), ",")) + allDomains = append(allDomains, strings.Join(p.Domains[i].ToStrArray(), ",")) } } - return searchUncheckedDomains(domainsToCheck, allCerts) + return searchUncheckedDomains(domainsToCheck, allDomains) } func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) []string { diff --git a/provider/acme/provider_test.go b/provider/acme/provider_test.go index 2f6a3db96..384ed5b0c 100644 --- a/provider/acme/provider_test.go +++ b/provider/acme/provider_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/containous/traefik/safe" + traefiktls "github.com/containous/traefik/tls" "github.com/containous/traefik/types" "github.com/stretchr/testify/assert" ) @@ -25,7 +26,7 @@ func TestGetUncheckedCertificates(t *testing.T) { testCases := []struct { desc string dynamicCerts *safe.Safe - staticCerts map[string]*tls.Certificate + staticCerts *safe.Safe acmeCertificates []*Certificate domains []string expectedDomains []string @@ -44,7 +45,7 @@ func TestGetUncheckedCertificates(t *testing.T) { { desc: "wildcard already exists in static certificates", domains: []string{"*.traefik.wtf"}, - staticCerts: wildcardMap, + staticCerts: wildcardSafe, expectedDomains: nil, }, { @@ -71,7 +72,7 @@ func TestGetUncheckedCertificates(t *testing.T) { { desc: "domain CN already exists in static certificates and SANs to generate", domains: []string{"traefik.wtf", "foo.traefik.wtf"}, - staticCerts: domainMap, + staticCerts: domainSafe, expectedDomains: []string{"foo.traefik.wtf"}, }, { @@ -93,7 +94,7 @@ func TestGetUncheckedCertificates(t *testing.T) { { desc: "domain already exists in static certificates", domains: []string{"traefik.wtf"}, - staticCerts: domainMap, + staticCerts: domainSafe, expectedDomains: nil, }, { @@ -115,7 +116,7 @@ func TestGetUncheckedCertificates(t *testing.T) { { desc: "domain matched by wildcard in static certificates", domains: []string{"who.traefik.wtf", "foo.traefik.wtf"}, - staticCerts: wildcardMap, + staticCerts: wildcardSafe, expectedDomains: nil, }, { @@ -146,8 +147,10 @@ func TestGetUncheckedCertificates(t *testing.T) { t.Parallel() acmeProvider := Provider{ - dynamicCerts: test.dynamicCerts, - staticCerts: test.staticCerts, + certificateStore: traefiktls.CertificateStore{ + DynamicCerts: test.dynamicCerts, + StaticCerts: test.staticCerts, + }, certificates: test.acmeCertificates, } diff --git a/server/server.go b/server/server.go index ae16ccd28..569d2445d 100644 --- a/server/server.go +++ b/server/server.go @@ -35,7 +35,6 @@ import ( "github.com/containous/traefik/middlewares/redirect" "github.com/containous/traefik/middlewares/tracing" "github.com/containous/traefik/provider" - "github.com/containous/traefik/provider/acme" "github.com/containous/traefik/rules" "github.com/containous/traefik/safe" "github.com/containous/traefik/server/cookie" @@ -81,8 +80,10 @@ type Server struct { // EntryPoint entryPoint information (configuration + internalRouter) type EntryPoint struct { - InternalRouter types.InternalRouter - Configuration *configuration.EntryPoint + InternalRouter types.InternalRouter + Configuration *configuration.EntryPoint + OnDemandListener func(string) (*tls.Certificate, error) + CertificateStore *traefiktls.CertificateStore } type serverEntryPoints map[string]*serverEntryPoint @@ -502,11 +503,6 @@ func (s *Server) AddListener(listener func(types.Configuration)) { s.configurationListeners = append(s.configurationListeners, listener) } -// SetOnDemandListener adds a new listener function used when a request is caught -func (s *serverEntryPoint) SetOnDemandListener(listener func(string) (*tls.Certificate, error)) { - s.onDemandListener = listener -} - // loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) (map[string]map[string]*tls.Certificate, error) { newEPCertificates := make(map[string]map[string]*tls.Certificate) @@ -693,14 +689,8 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL // in each certificate and populates the config.NameToCertificate map. config.BuildNameToCertificate() - if acme.IsEnabled() { - if entryPointName == acme.Get().EntryPoint { - acme.Get().SetStaticCertificates(config.NameToCertificate) - acme.Get().SetDynamicCertificates(&s.serverEntryPoints[entryPointName].certs) - if acme.Get().OnDemand { - s.serverEntryPoints[entryPointName].SetOnDemandListener(acme.Get().ListenRequest) - } - } + if s.entryPoints[entryPointName].CertificateStore != nil { + s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(config.NameToCertificate) } // Set the minimum TLS version if set in the config TOML @@ -839,9 +829,13 @@ func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTi func (s *Server) buildEntryPoints() map[string]*serverEntryPoint { serverEntryPoints := make(map[string]*serverEntryPoint) - for entryPointName := range s.entryPoints { + for entryPointName, entryPoint := range s.entryPoints { serverEntryPoints[entryPointName] = &serverEntryPoint{ - httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()), + httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()), + onDemandListener: entryPoint.OnDemandListener, + } + if entryPoint.CertificateStore != nil { + serverEntryPoints[entryPointName].certs = *entryPoint.CertificateStore.DynamicCerts } } return serverEntryPoints diff --git a/tls/certificate_store.go b/tls/certificate_store.go new file mode 100644 index 000000000..c70e1710f --- /dev/null +++ b/tls/certificate_store.go @@ -0,0 +1,33 @@ +package tls + +import ( + "crypto/tls" + + "github.com/containous/traefik/safe" +) + +// CertificateStore store for dynamic and static certificates +type CertificateStore struct { + DynamicCerts *safe.Safe + StaticCerts *safe.Safe +} + +// GetAllDomains return a slice with all the certificate domain +func (c CertificateStore) GetAllDomains() []string { + var allCerts []string + + // Get static certificates + if c.StaticCerts != nil && c.StaticCerts.Get() != nil { + for domains := range c.StaticCerts.Get().(map[string]*tls.Certificate) { + allCerts = append(allCerts, domains) + } + } + + // Get dynamic certificates + if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil { + for domains := range c.DynamicCerts.Get().(map[string]*tls.Certificate) { + allCerts = append(allCerts, domains) + } + } + return allCerts +}