From d81c4e6d1a36f8289f2acb2980f81fbb0d69d555 Mon Sep 17 00:00:00 2001 From: NicoMen Date: Mon, 20 Aug 2018 09:40:03 +0200 Subject: [PATCH 1/9] Avoid duplicated ACME resolution --- acme/acme.go | 39 ++++++++++++++++ acme/acme_test.go | 10 ++++- .../acme/manage_acme_docker_environment.sh | 2 +- provider/acme/provider.go | 37 ++++++++++++++- provider/acme/provider_test.go | 45 +++++++++++++++++-- 5 files changed, 126 insertions(+), 7 deletions(-) diff --git a/acme/acme.go b/acme/acme.go index e31bbbdc3..8a06b823d 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -11,6 +11,7 @@ import ( "net/http" "reflect" "strings" + "sync" "time" "github.com/BurntSushi/ty/fun" @@ -61,6 +62,8 @@ type ACME struct { jobs *channels.InfiniteChannel TLSConfig *tls.Config `description:"TLS config in case wildcard certs are used"` dynamicCerts *safe.Safe + resolvingDomains map[string]struct{} + resolvingDomainsMutex sync.RWMutex } func (a *ACME) init() error { @@ -81,6 +84,10 @@ func (a *ACME) init() error { a.defaultCertificate = cert a.jobs = channels.NewInfiniteChannel() + + // Init the currently resolved domain map + a.resolvingDomains = make(map[string]struct{}) + return nil } @@ -502,6 +509,10 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { if len(uncheckedDomains) == 0 { return } + + a.addResolvingDomains(uncheckedDomains) + defer a.removeResolvingDomains(uncheckedDomains) + certificate, err := a.getDomainsCertificates(uncheckedDomains) if err != nil { log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err) @@ -533,6 +544,24 @@ func (a *ACME) LoadCertificateForDomains(domains []string) { } } +func (a *ACME) addResolvingDomains(resolvingDomains []string) { + a.resolvingDomainsMutex.Lock() + defer a.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + a.resolvingDomains[domain] = struct{}{} + } +} + +func (a *ACME) removeResolvingDomains(resolvingDomains []string) { + a.resolvingDomainsMutex.Lock() + defer a.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + delete(a.resolvingDomains, domain) + } +} + // Get provided certificate which check a domains list (Main and SANs) // from static and dynamic provided certificates func (a *ACME) getProvidedCertificate(domains string) *tls.Certificate { @@ -568,6 +597,9 @@ func searchProvidedCertificateForDomains(domain string, certs map[string]*tls.Ce // Get provided certificate which check a domains list (Main and SANs) // from static and dynamic provided certificates func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string { + a.resolvingDomainsMutex.RLock() + defer a.resolvingDomainsMutex.RUnlock() + log.Debugf("Looking for provided certificate to validate %s...", domains) allCerts := make(map[string]*tls.Certificate) @@ -590,6 +622,13 @@ func (a *ACME) getUncheckedDomains(domains []string, account *Account) []string } } + // Get currently resolved domains + for domain := range a.resolvingDomains { + if _, ok := allCerts[domain]; !ok { + allCerts[domain] = &tls.Certificate{} + } + } + // Get Configuration Domains for i := 0; i < len(a.Domains); i++ { allCerts[a.Domains[i].Main] = &tls.Certificate{} diff --git a/acme/acme_test.go b/acme/acme_test.go index 9e3d2ace4..aadfa17b6 100644 --- a/acme/acme_test.go +++ b/acme/acme_test.go @@ -331,9 +331,12 @@ func TestAcme_getUncheckedCertificates(t *testing.T) { mm["*.containo.us"] = &tls.Certificate{} mm["traefik.acme.io"] = &tls.Certificate{} - a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}} + dm := make(map[string]struct{}) + dm["*.traefik.wtf"] = struct{}{} - domains := []string{"traefik.containo.us", "trae.containo.us"} + a := ACME{TLSConfig: &tls.Config{NameToCertificate: mm}, resolvingDomains: dm} + + domains := []string{"traefik.containo.us", "trae.containo.us", "foo.traefik.wtf"} uncheckedDomains := a.getUncheckedDomains(domains, nil) assert.Empty(t, uncheckedDomains) domains = []string{"traefik.acme.io", "trae.acme.io"} @@ -351,6 +354,9 @@ func TestAcme_getUncheckedCertificates(t *testing.T) { account := Account{DomainsCertificate: domainsCertificates} uncheckedDomains = a.getUncheckedDomains(domains, &account) assert.Empty(t, uncheckedDomains) + domains = []string{"traefik.containo.us", "trae.containo.us", "traefik.wtf"} + uncheckedDomains = a.getUncheckedDomains(domains, nil) + assert.Len(t, uncheckedDomains, 1) } func TestAcme_getProvidedCertificate(t *testing.T) { diff --git a/examples/acme/manage_acme_docker_environment.sh b/examples/acme/manage_acme_docker_environment.sh index a95483c9e..58cf73362 100755 --- a/examples/acme/manage_acme_docker_environment.sh +++ b/examples/acme/manage_acme_docker_environment.sh @@ -50,7 +50,7 @@ start_boulder() { # Script usage show_usage() { echo - echo "USAGE : manage_acme_docker_environment.sh [--start|--stop|--restart]" + echo "USAGE : manage_acme_docker_environment.sh [--dev|--start|--stop|--restart]" echo } diff --git a/provider/acme/provider.go b/provider/acme/provider.go index 958a77707..74adf62d4 100644 --- a/provider/acme/provider.go +++ b/provider/acme/provider.go @@ -63,6 +63,8 @@ type Provider struct { clientMutex sync.Mutex configFromListenerChan chan types.Configuration pool *safe.Pool + resolvingDomains map[string]struct{} + resolvingDomainsMutex sync.RWMutex } // Certificate is a struct which contains all data needed from an ACME certificate @@ -127,6 +129,9 @@ func (p *Provider) init() error { return fmt.Errorf("unable to get ACME certificates : %v", err) } + // Init the currently resolved domain map + p.resolvingDomains = make(map[string]struct{}) + p.watchCertificate() p.watchNewDomains() @@ -226,6 +231,9 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati return nil, nil } + p.addResolvingDomains(uncheckedDomains) + defer p.removeResolvingDomains(uncheckedDomains) + log.Debugf("Loading ACME certificates %+v...", uncheckedDomains) client, err := p.getClient() if err != nil { @@ -254,6 +262,24 @@ func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurati return certificate, nil } +func (p *Provider) removeResolvingDomains(resolvingDomains []string) { + p.resolvingDomainsMutex.Lock() + defer p.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + delete(p.resolvingDomains, domain) + } +} + +func (p *Provider) addResolvingDomains(resolvingDomains []string) { + p.resolvingDomainsMutex.Lock() + defer p.resolvingDomainsMutex.Unlock() + + for _, domain := range resolvingDomains { + p.resolvingDomains[domain] = struct{}{} + } +} + func (p *Provider) getClient() (*acme.Client, error) { p.clientMutex.Lock() defer p.clientMutex.Unlock() @@ -503,6 +529,9 @@ func (p *Provider) AddRoutes(router *mux.Router) { // Get provided certificate which check a domains list (Main and SANs) // from static and dynamic provided certificates func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string { + p.resolvingDomainsMutex.RLock() + defer p.resolvingDomainsMutex.RUnlock() + log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck) var allCerts []string @@ -523,6 +552,11 @@ func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurati allCerts = append(allCerts, strings.Join(certificate.Domain.ToStrArray(), ",")) } + // Get currently resolved domains + for domain := range p.resolvingDomains { + allCerts = append(allCerts, domain) + } + // Get Configuration Domains if checkConfigurationDomains { for i := 0; i < len(p.Domains); i++ { @@ -540,8 +574,9 @@ func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) [ uncheckedDomains = append(uncheckedDomains, domainToCheck) } } + if len(uncheckedDomains) == 0 { - log.Debugf("No ACME certificate to generate for domains %q.", domainsToCheck) + log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck) } else { log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ",")) } diff --git a/provider/acme/provider_test.go b/provider/acme/provider_test.go index 2f6a3db96..5e79d60b8 100644 --- a/provider/acme/provider_test.go +++ b/provider/acme/provider_test.go @@ -26,6 +26,7 @@ func TestGetUncheckedCertificates(t *testing.T) { desc string dynamicCerts *safe.Safe staticCerts map[string]*tls.Certificate + resolvingDomains map[string]struct{} acmeCertificates []*Certificate domains []string expectedDomains []string @@ -138,17 +139,55 @@ func TestGetUncheckedCertificates(t *testing.T) { }, expectedDomains: []string{"traefik.wtf"}, }, + { + desc: "all domains already managed by ACME", + domains: []string{"traefik.wtf", "foo.traefik.wtf"}, + resolvingDomains: map[string]struct{}{ + "traefik.wtf": {}, + "foo.traefik.wtf": {}, + }, + expectedDomains: []string{}, + }, + { + desc: "one domain already managed by ACME", + domains: []string{"traefik.wtf", "foo.traefik.wtf"}, + resolvingDomains: map[string]struct{}{ + "traefik.wtf": {}, + }, + expectedDomains: []string{"foo.traefik.wtf"}, + }, + { + desc: "wildcard domain already managed by ACME checks the domains", + domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"}, + resolvingDomains: map[string]struct{}{ + "*.traefik.wtf": {}, + }, + expectedDomains: []string{}, + }, + { + desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked", + domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"}, + resolvingDomains: map[string]struct{}{ + "*.traefik.wtf": {}, + "traefik.wtf": {}, + }, + expectedDomains: []string{"acme.wtf"}, + }, } for _, test := range testCases { test := test + if test.resolvingDomains == nil { + test.resolvingDomains = make(map[string]struct{}) + } t.Run(test.desc, func(t *testing.T) { t.Parallel() acmeProvider := Provider{ - dynamicCerts: test.dynamicCerts, - staticCerts: test.staticCerts, - certificates: test.acmeCertificates, + dynamicCerts: test.dynamicCerts, + staticCerts: test.staticCerts, + certificates: test.acmeCertificates, + resolvingDomains: test.resolvingDomains, } domains := acmeProvider.getUncheckedDomains(test.domains, false) From 07be89d6e937baf00df30a7f254d19531aa6d45e Mon Sep 17 00:00:00 2001 From: SALLEYRON Julien Date: Mon, 20 Aug 2018 10:38:03 +0200 Subject: [PATCH 2/9] Update oxy dependency --- Gopkg.lock | 2 +- server/errorhandler.go | 20 ++- server/server.go | 117 ++++++++++++++++-- .../github.com/vulcand/oxy/buffer/buffer.go | 57 +++++---- .../vulcand/oxy/buffer/threshold.go | 1 + .../vulcand/oxy/cbreaker/cbreaker.go | 40 +++--- .../github.com/vulcand/oxy/cbreaker/effect.go | 18 ++- .../vulcand/oxy/cbreaker/fallback.go | 39 ++++-- .../vulcand/oxy/cbreaker/predicates.go | 3 +- .../github.com/vulcand/oxy/cbreaker/ratio.go | 12 +- .../vulcand/oxy/connlimit/connlimit.go | 36 ++++-- vendor/github.com/vulcand/oxy/forward/fwd.go | 84 ++++++++++--- .../github.com/vulcand/oxy/forward/headers.go | 6 +- .../github.com/vulcand/oxy/forward/rewrite.go | 7 +- .../vulcand/oxy/memmetrics/anomaly.go | 6 +- .../vulcand/oxy/memmetrics/counter.go | 12 +- .../vulcand/oxy/memmetrics/histogram.go | 30 +++-- .../vulcand/oxy/memmetrics/ratio.go | 17 +++ .../vulcand/oxy/memmetrics/roundtrip.go | 32 +++-- .../vulcand/oxy/ratelimit/bucket.go | 9 +- .../vulcand/oxy/ratelimit/bucketset.go | 8 +- .../vulcand/oxy/ratelimit/tokenlimiter.go | 49 ++++++-- .../oxy/roundrobin/RequestRewriteListener.go | 1 + .../vulcand/oxy/roundrobin/rebalancer.go | 62 +++++++--- .../github.com/vulcand/oxy/roundrobin/rr.go | 41 ++++-- .../vulcand/oxy/roundrobin/stickysessions.go | 9 +- vendor/github.com/vulcand/oxy/utils/auth.go | 2 + .../github.com/vulcand/oxy/utils/dumpreq.go | 14 ++- .../github.com/vulcand/oxy/utils/handler.go | 30 ++++- .../github.com/vulcand/oxy/utils/netutils.go | 55 +++++--- vendor/github.com/vulcand/oxy/utils/source.go | 12 +- 31 files changed, 636 insertions(+), 195 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index 3cfee2ff8..7371071c5 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -1217,7 +1217,7 @@ "roundrobin", "utils" ] - revision = "c2414f4542f085363f490048da2fbec5e4537eb6" + revision = "885e42fe04d8e0efa6c18facad4e0fc5757cde9b" [[projects]] name = "github.com/vulcand/predicate" diff --git a/server/errorhandler.go b/server/errorhandler.go index 732b2cc0e..de58045ed 100644 --- a/server/errorhandler.go +++ b/server/errorhandler.go @@ -1,6 +1,7 @@ package server import ( + "context" "io" "net" "net/http" @@ -9,6 +10,12 @@ import ( "github.com/containous/traefik/middlewares" ) +// StatusClientClosedRequest non-standard HTTP status code for client disconnection +const StatusClientClosedRequest = 499 + +// StatusClientClosedRequestText non-standard HTTP status for client disconnection +const StatusClientClosedRequestText = "Client Closed Request" + // RecordingErrorHandler is an error handler, implementing the vulcand/oxy // error handler interface, which is recording network errors by using the netErrorRecorder. // In addition it sets a proper HTTP status code and body, depending on the type of error occurred. @@ -34,9 +41,18 @@ func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Requ } else if err == io.EOF { eh.netErrorRecorder.Record(req.Context()) statusCode = http.StatusBadGateway + } else if err == context.Canceled { + statusCode = StatusClientClosedRequest } w.WriteHeader(statusCode) - w.Write([]byte(http.StatusText(statusCode))) - log.Debugf("'%d %s' caused by: %v", statusCode, http.StatusText(statusCode), err) + w.Write([]byte(statusText(statusCode))) + log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err) +} + +func statusText(statusCode int) string { + if statusCode == StatusClientClosedRequest { + return StatusClientClosedRequestText + } + return http.StatusText(statusCode) } diff --git a/server/server.go b/server/server.go index c15fb6927..5db244e92 100644 --- a/server/server.go +++ b/server/server.go @@ -80,14 +80,96 @@ type Server struct { bufferPool httputil.BufferPool } +func newHijackConnectionTracker() *hijackConnectionTracker { + return &hijackConnectionTracker{ + conns: make(map[net.Conn]struct{}), + } +} + +type hijackConnectionTracker struct { + conns map[net.Conn]struct{} + lock sync.RWMutex +} + +// AddHijackedConnection add a connection in the tracked connections list +func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) { + h.lock.Lock() + defer h.lock.Unlock() + h.conns[conn] = struct{}{} +} + +// RemoveHijackedConnection remove a connection from the tracked connections list +func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) { + h.lock.Lock() + defer h.lock.Unlock() + delete(h.conns, conn) +} + +// Shutdown wait for the connection closing +func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + h.lock.RLock() + if len(h.conns) == 0 { + return nil + } + h.lock.RUnlock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// Close close all the connections in the tracked connections list +func (h *hijackConnectionTracker) Close() { + for conn := range h.conns { + if err := conn.Close(); err != nil { + log.Errorf("Error while closing Hijacked conn: %v", err) + } + delete(h.conns, conn) + } +} + type serverEntryPoints map[string]*serverEntryPoint type serverEntryPoint struct { - httpServer *http.Server - listener net.Listener - httpRouter *middlewares.HandlerSwitcher - certs safe.Safe - onDemandListener func(string) (*tls.Certificate, error) + httpServer *http.Server + listener net.Listener + httpRouter *middlewares.HandlerSwitcher + certs safe.Safe + onDemandListener func(string) (*tls.Certificate, error) + hijackConnectionTracker *hijackConnectionTracker +} + +func (s serverEntryPoint) Shutdown(ctx context.Context) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := s.httpServer.Shutdown(ctx); err != nil { + if ctx.Err() == context.DeadlineExceeded { + log.Debugf("Wait server shutdown is over due to: %s", err) + err = s.httpServer.Close() + if err != nil { + log.Error(err) + } + } + } + }() + wg.Add(1) + go func() { + defer wg.Done() + if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil { + if ctx.Err() == context.DeadlineExceeded { + log.Debugf("Wait hijack connection is over due to: %s", err) + s.hijackConnectionTracker.Close() + } + } + }() + wg.Wait() } // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted @@ -260,10 +342,7 @@ func (s *Server) Stop() { graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut) ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut) log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName) - if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil { - log.Debugf("Wait is over due to: %s", err) - serverEntryPoint.httpServer.Close() - } + serverEntryPoint.Shutdown(ctx) cancel() log.Debugf("Entrypoint %s closed", serverEntryPointName) }(sepn, sep) @@ -376,9 +455,20 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer log.Fatal("Error preparing server: ", err) } serverEntryPoint := s.serverEntryPoints[newServerEntryPointName] + serverEntryPoint.httpServer = newSrv serverEntryPoint.listener = listener + serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker() + serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) { + switch state { + case http.StateHijacked: + serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn) + case http.StateClosed: + serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn) + } + } + return serverEntryPoint } @@ -1025,6 +1115,15 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura forward.Rewriter(rewriter), forward.ResponseModifier(responseModifier), forward.BufferPool(s.bufferPool), + forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) { + server := req.Context().Value(http.ServerContextKey).(*http.Server) + if server != nil { + connState := server.ConnState + if connState != nil { + connState(conn, http.StateClosed) + } + } + }), ) if err != nil { diff --git a/vendor/github.com/vulcand/oxy/buffer/buffer.go b/vendor/github.com/vulcand/oxy/buffer/buffer.go index e3ee40b33..d2bbe40ce 100644 --- a/vendor/github.com/vulcand/oxy/buffer/buffer.go +++ b/vendor/github.com/vulcand/oxy/buffer/buffer.go @@ -36,13 +36,12 @@ Examples of a buffering middleware: package buffer import ( + "bufio" "fmt" "io" "io/ioutil" - "net/http" - - "bufio" "net" + "net/http" "reflect" "github.com/mailgun/multibuf" @@ -74,6 +73,8 @@ type Buffer struct { next http.Handler errHandler utils.ErrorHandler + + log *log.Logger } // New returns a new buffer middleware. New() function supports optional functional arguments @@ -86,6 +87,8 @@ func New(next http.Handler, setters ...optSetter) (*Buffer, error) { maxResponseBodyBytes: DefaultMaxBodyBytes, memResponseBodyBytes: DefaultMemBodyBytes, + + log: log.StandardLogger(), } for _, s := range setters { if err := s(strm); err != nil { @@ -99,6 +102,16 @@ func New(next http.Handler, setters ...optSetter) (*Buffer, error) { return strm, nil } +// Logger defines the logger the buffer will use. +// +// It defaults to logrus.StandardLogger(), the global logger used by logrus. +func Logger(l *log.Logger) optSetter { + return func(b *Buffer) error { + b.log = l + return nil + } +} + type optSetter func(b *Buffer) error // CondSetter Conditional setter. @@ -154,7 +167,7 @@ func MaxRequestBodyBytes(m int64) optSetter { } } -// MaxRequestBody bytes sets the maximum request body to be stored in memory +// MemRequestBodyBytes bytes sets the maximum request body to be stored in memory // buffer middleware will serialize the excess to disk. func MemRequestBodyBytes(m int64) optSetter { return func(b *Buffer) error { @@ -196,8 +209,8 @@ func (b *Buffer) Wrap(next http.Handler) error { } func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if log.GetLevel() >= log.DebugLevel { - logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + if b.log.Level >= log.DebugLevel { + logEntry := b.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/buffer: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/buffer: completed ServeHttp on request") } @@ -210,11 +223,11 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Read the body while keeping limits in mind. This reader controls the maximum bytes // to read into memory and disk. This reader returns an error if the total request size exceeds the - // prefefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 + // predefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 // and the reader would be unbounded bufio in the http.Server body, err := multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes)) if err != nil || body == nil { - log.Errorf("vulcand/oxy/buffer: error when reading request body, err: %v", err) + b.log.Errorf("vulcand/oxy/buffer: error when reading request body, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } @@ -235,7 +248,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { // set without content length or using chunked TransferEncoding totalSize, err := body.Size() if err != nil { - log.Errorf("vulcand/oxy/buffer: failed to get request size, err: %v", err) + b.log.Errorf("vulcand/oxy/buffer: failed to get request size, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } @@ -251,7 +264,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { // We create a special writer that will limit the response size, buffer it to disk if necessary writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(b.maxResponseBodyBytes), multibuf.MemBytes(b.memResponseBodyBytes)) if err != nil { - log.Errorf("vulcand/oxy/buffer: failed create response writer, err: %v", err) + b.log.Errorf("vulcand/oxy/buffer: failed create response writer, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } @@ -261,12 +274,13 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { header: make(http.Header), buffer: writer, responseWriter: w, + log: b.log, } defer bw.Close() b.next.ServeHTTP(bw, outreq) if bw.hijacked { - log.Debugf("vulcand/oxy/buffer: connection was hijacked downstream. Not taking any action in buffer.") + b.log.Debugf("vulcand/oxy/buffer: connection was hijacked downstream. Not taking any action in buffer.") return } @@ -274,7 +288,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { if bw.expectBody(outreq) { rdr, err := writer.Reader() if err != nil { - log.Errorf("vulcand/oxy/buffer: failed to read response, err: %v", err) + b.log.Errorf("vulcand/oxy/buffer: failed to read response, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } @@ -292,17 +306,17 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - attempt += 1 + attempt++ if body != nil { if _, err := body.Seek(0, 0); err != nil { - log.Errorf("vulcand/oxy/buffer: failed to rewind response body, err: %v", err) + b.log.Errorf("vulcand/oxy/buffer: failed to rewind response body, err: %v", err) b.errHandler.ServeHTTP(w, req, err) return } } outreq = b.copyRequest(req, body, totalSize) - log.Debugf("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt) + b.log.Debugf("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt) } } @@ -339,6 +353,7 @@ type bufferWriter struct { buffer multibuf.WriterOnce responseWriter http.ResponseWriter hijacked bool + log *log.Logger } // RFC2616 #4.4 @@ -376,16 +391,16 @@ func (b *bufferWriter) WriteHeader(code int) { b.code = code } -//CloseNotifier interface - this allows downstream connections to be terminated when the client terminates. +// CloseNotifier interface - this allows downstream connections to be terminated when the client terminates. func (b *bufferWriter) CloseNotify() <-chan bool { if cn, ok := b.responseWriter.(http.CloseNotifier); ok { return cn.CloseNotify() } - log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) + b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) return make(<-chan bool) } -//This allows connections to be hijacked for websockets for instance. +// Hijack This allows connections to be hijacked for websockets for instance. func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := b.responseWriter.(http.Hijacker); ok { conn, rw, err := hi.Hijack() @@ -394,12 +409,12 @@ func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } return conn, rw, err } - log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) + b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.responseWriter)) } -type SizeErrHandler struct { -} +// SizeErrHandler Size error handler +type SizeErrHandler struct{} func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { if _, ok := err.(*multibuf.MaxSizeReachedError); ok { diff --git a/vendor/github.com/vulcand/oxy/buffer/threshold.go b/vendor/github.com/vulcand/oxy/buffer/threshold.go index 1294bcd60..0fdde7da4 100644 --- a/vendor/github.com/vulcand/oxy/buffer/threshold.go +++ b/vendor/github.com/vulcand/oxy/buffer/threshold.go @@ -7,6 +7,7 @@ import ( "github.com/vulcand/predicate" ) +// IsValidExpression check if it's a valid expression func IsValidExpression(expr string) bool { _, err := parseExpression(expr) return err == nil diff --git a/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go b/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go index 5991a8474..4c35f1365 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go @@ -3,7 +3,7 @@ // Vulcan circuit breaker watches the error condtion to match // after which it activates the fallback scenario, e.g. returns the response code // or redirects the request to another location - +// // Circuit breakers start in the Standby state first, observing responses and watching location metrics. // // Once the Circuit breaker condition is met, it enters the "Tripped" state, where it activates fallback scenario @@ -31,9 +31,8 @@ import ( "sync" "time" - log "github.com/sirupsen/logrus" - "github.com/mailgun/timetools" + log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/memmetrics" "github.com/vulcand/oxy/utils" ) @@ -63,6 +62,8 @@ type CircuitBreaker struct { next http.Handler clock timetools.TimeProvider + + log *log.Logger } // New creates a new CircuitBreaker middleware @@ -76,6 +77,7 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption) fallbackDuration: defaultFallbackDuration, recoveryDuration: defaultRecoveryDuration, fallback: defaultFallback, + log: log.StandardLogger(), } for _, s := range options { @@ -99,9 +101,19 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption) return cb, nil } +// Logger defines the logger the circuit breaker will use. +// +// It defaults to logrus.StandardLogger(), the global logger used by logrus. +func Logger(l *log.Logger) CircuitBreakerOption { + return func(c *CircuitBreaker) error { + c.log = l + return nil + } +} + func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if log.GetLevel() >= log.DebugLevel { - logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + if c.log.Level >= log.DebugLevel { + logEntry := c.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/circuitbreaker: completed ServeHttp on request") } @@ -112,6 +124,7 @@ func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) { c.serve(w, req) } +// Wrap sets the next handler to be called by circuit breaker handler. func (c *CircuitBreaker) Wrap(next http.Handler) { c.next = next } @@ -126,7 +139,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque c.m.Lock() defer c.m.Unlock() - log.Warnf("%v is in error state", c) + c.log.Warnf("%v is in error state", c) switch c.state { case stateStandby: @@ -156,7 +169,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque func (c *CircuitBreaker) serve(w http.ResponseWriter, req *http.Request) { start := c.clock.UtcNow() - p := utils.NewProxyWriter(w) + p := utils.NewProxyWriterWithLogger(w, c.log) c.next.ServeHTTP(p, req) @@ -191,13 +204,13 @@ func (c *CircuitBreaker) exec(s SideEffect) { } go func() { if err := s.Exec(); err != nil { - log.Errorf("%v side effect failure: %v", c, err) + c.log.Errorf("%v side effect failure: %v", c, err) } }() } func (c *CircuitBreaker) setState(new cbState, until time.Time) { - log.Debugf("%v setting state to %v, until %v", c, new, until) + c.log.Debugf("%v setting state to %v, until %v", c, new, until) c.state = new c.until = until switch new { @@ -230,7 +243,7 @@ func (c *CircuitBreaker) checkAndSet() { c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod) if c.state == stateTripped { - log.Debugf("%v skip set tripped", c) + c.log.Debugf("%v skip set tripped", c) return } @@ -244,7 +257,7 @@ func (c *CircuitBreaker) checkAndSet() { func (c *CircuitBreaker) setRecovering() { c.setState(stateRecovering, c.clock.UtcNow().Add(c.recoveryDuration)) - c.rc = newRatioController(c.clock, c.recoveryDuration) + c.rc = newRatioController(c.clock, c.recoveryDuration, c.log) } // CircuitBreakerOption represents an option you can pass to New. @@ -296,7 +309,7 @@ func OnTripped(s SideEffect) CircuitBreakerOption { } } -// OnTripped sets a SideEffect to run when entering the Standby state. +// OnStandby sets a SideEffect to run when entering the Standby state. // Only one SideEffect can be set for this hook. func OnStandby(s SideEffect) CircuitBreakerOption { return func(c *CircuitBreaker) error { @@ -346,8 +359,7 @@ const ( var defaultFallback = &fallback{} -type fallback struct { -} +type fallback struct{} func (f *fallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) diff --git a/vendor/github.com/vulcand/oxy/cbreaker/effect.go b/vendor/github.com/vulcand/oxy/cbreaker/effect.go index 821115491..88aae1426 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/effect.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/effect.go @@ -13,10 +13,12 @@ import ( "github.com/vulcand/oxy/utils" ) +// SideEffect a side effect type SideEffect interface { Exec() error } +// Webhook Web hook type Webhook struct { URL string Method string @@ -25,11 +27,15 @@ type Webhook struct { Body []byte } +// WebhookSideEffect a web hook side effect type WebhookSideEffect struct { w Webhook + + log *log.Logger } -func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { +// NewWebhookSideEffectsWithLogger creates a new WebhookSideEffect +func NewWebhookSideEffectsWithLogger(w Webhook, l *log.Logger) (*WebhookSideEffect, error) { if w.Method == "" { return nil, fmt.Errorf("Supply method") } @@ -38,7 +44,12 @@ func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { return nil, err } - return &WebhookSideEffect{w: w}, nil + return &WebhookSideEffect{w: w, log: l}, nil +} + +// NewWebhookSideEffect creates a new WebhookSideEffect +func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { + return NewWebhookSideEffectsWithLogger(w, log.StandardLogger()) } func (w *WebhookSideEffect) getBody() io.Reader { @@ -51,6 +62,7 @@ func (w *WebhookSideEffect) getBody() io.Reader { return nil } +// Exec execute the side effect func (w *WebhookSideEffect) Exec() error { r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody()) if err != nil { @@ -73,6 +85,6 @@ func (w *WebhookSideEffect) Exec() error { if err != nil { return err } - log.Debugf("%v got response: (%s): %s", w, re.Status, string(body)) + w.log.Debugf("%v got response: (%s): %s", w, re.Status, string(body)) return nil } diff --git a/vendor/github.com/vulcand/oxy/cbreaker/fallback.go b/vendor/github.com/vulcand/oxy/cbreaker/fallback.go index a4fed70af..ea0655311 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/fallback.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/fallback.go @@ -10,26 +10,36 @@ import ( "github.com/vulcand/oxy/utils" ) +// Response response model type Response struct { StatusCode int ContentType string Body []byte } +// ResponseFallback fallback response handler type ResponseFallback struct { r Response + + log *log.Logger } -func NewResponseFallback(r Response) (*ResponseFallback, error) { +// NewResponseFallbackWithLogger creates a new ResponseFallback +func NewResponseFallbackWithLogger(r Response, l *log.Logger) (*ResponseFallback, error) { if r.StatusCode == 0 { return nil, fmt.Errorf("response code should not be 0") } - return &ResponseFallback{r: r}, nil + return &ResponseFallback{r: r, log: l}, nil +} + +// NewResponseFallback creates a new ResponseFallback +func NewResponseFallback(r Response) (*ResponseFallback, error) { + return NewResponseFallbackWithLogger(r, log.StandardLogger()) } func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if log.GetLevel() >= log.DebugLevel { - logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + if f.log.Level >= log.DebugLevel { + logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/fallback/response: completed ServeHttp on request") } @@ -45,27 +55,38 @@ func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } +// Redirect redirect model type Redirect struct { URL string PreservePath bool } +// RedirectFallback fallback redirect handler type RedirectFallback struct { - u *url.URL r Redirect + + u *url.URL + + log *log.Logger } -func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { +// NewRedirectFallbackWithLogger creates a new RedirectFallback +func NewRedirectFallbackWithLogger(r Redirect, l *log.Logger) (*RedirectFallback, error) { u, err := url.ParseRequestURI(r.URL) if err != nil { return nil, err } - return &RedirectFallback{u: u, r: r}, nil + return &RedirectFallback{r: r, u: u, log: l}, nil +} + +// NewRedirectFallback creates a new RedirectFallback +func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { + return NewRedirectFallbackWithLogger(r, log.StandardLogger()) } func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if log.GetLevel() >= log.DebugLevel { - logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + if f.log.Level >= log.DebugLevel { + logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/fallback/redirect: completed ServeHttp on request") } diff --git a/vendor/github.com/vulcand/oxy/cbreaker/predicates.go b/vendor/github.com/vulcand/oxy/cbreaker/predicates.go index e703156d3..a858daf8c 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/predicates.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/predicates.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - log "github.com/sirupsen/logrus" "github.com/vulcand/predicate" ) @@ -50,7 +49,7 @@ func latencyAtQuantile(quantile float64) toInt { return func(c *CircuitBreaker) int { h, err := c.metrics.LatencyHistogram() if err != nil { - log.Errorf("Failed to get latency histogram, for %v error: %v", c, err) + c.log.Errorf("Failed to get latency histogram, for %v error: %v", c, err) return 0 } return int(h.LatencyAtQuantile(quantile) / time.Millisecond) diff --git a/vendor/github.com/vulcand/oxy/cbreaker/ratio.go b/vendor/github.com/vulcand/oxy/cbreaker/ratio.go index 4918ab8bf..96f9eeb7b 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/ratio.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/ratio.go @@ -19,13 +19,17 @@ type ratioController struct { tm timetools.TimeProvider allowed int denied int + + log *log.Logger } -func newRatioController(tm timetools.TimeProvider, rampUp time.Duration) *ratioController { +func newRatioController(tm timetools.TimeProvider, rampUp time.Duration, log *log.Logger) *ratioController { return &ratioController{ duration: rampUp, tm: tm, start: tm.UtcNow(), + + log: log, } } @@ -34,17 +38,17 @@ func (r *ratioController) String() string { } func (r *ratioController) allowRequest() bool { - log.Debugf("%v", r) + r.log.Debugf("%v", r) t := r.targetRatio() // This condition answers the question - would we satisfy the target ratio if we allow this request? e := r.computeRatio(r.allowed+1, r.denied) if e < t { r.allowed++ - log.Debugf("%v allowed", r) + r.log.Debugf("%v allowed", r) return true } r.denied++ - log.Debugf("%v denied", r) + r.log.Debugf("%v denied", r) return false } diff --git a/vendor/github.com/vulcand/oxy/connlimit/connlimit.go b/vendor/github.com/vulcand/oxy/connlimit/connlimit.go index c7b392758..5d2d71468 100644 --- a/vendor/github.com/vulcand/oxy/connlimit/connlimit.go +++ b/vendor/github.com/vulcand/oxy/connlimit/connlimit.go @@ -1,4 +1,4 @@ -// package connlimit provides control over simultaneous connections coming from the same source +// Package connlimit provides control over simultaneous connections coming from the same source package connlimit import ( @@ -10,7 +10,7 @@ import ( "github.com/vulcand/oxy/utils" ) -// Limiter tracks concurrent connection per token +// ConnLimiter tracks concurrent connection per token // and is capable of rejecting connections if they are failed type ConnLimiter struct { mutex *sync.Mutex @@ -21,8 +21,10 @@ type ConnLimiter struct { next http.Handler errHandler utils.ErrorHandler + log *log.Logger } +// New creates a new ConnLimiter func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) { if extract == nil { return nil, fmt.Errorf("Extract function can not be nil") @@ -33,6 +35,7 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, maxConnections: maxConnections, connections: make(map[string]int64), next: next, + log: log.StandardLogger(), } for _, o := range options { @@ -41,11 +44,24 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, } } if cl.errHandler == nil { - cl.errHandler = defaultErrHandler + cl.errHandler = &ConnErrHandler{ + log: cl.log, + } } return cl, nil } +// Logger defines the logger the connection limiter will use. +// +// It defaults to logrus.StandardLogger(), the global logger used by logrus. +func Logger(l *log.Logger) ConnLimitOption { + return func(cl *ConnLimiter) error { + cl.log = l + return nil + } +} + +// Wrap sets the next handler to be called by connexion limiter handler. func (cl *ConnLimiter) Wrap(h http.Handler) { cl.next = h } @@ -53,12 +69,12 @@ func (cl *ConnLimiter) Wrap(h http.Handler) { func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { token, amount, err := cl.extract.Extract(r) if err != nil { - log.Errorf("failed to extract source of the connection: %v", err) + cl.log.Errorf("failed to extract source of the connection: %v", err) cl.errHandler.ServeHTTP(w, r, err) return } if err := cl.acquire(token, amount); err != nil { - log.Debugf("limiting request source %s: %v", token, err) + cl.log.Debugf("limiting request source %s: %v", token, err) cl.errHandler.ServeHTTP(w, r, err) return } @@ -95,6 +111,7 @@ func (cl *ConnLimiter) release(token string, amount int64) { } } +// MaxConnError maximum connections reached error type MaxConnError struct { max int64 } @@ -103,12 +120,14 @@ func (m *MaxConnError) Error() string { return fmt.Sprintf("max connections reached: %d", m.max) } +// ConnErrHandler connection limiter error handler type ConnErrHandler struct { + log *log.Logger } func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { - if log.GetLevel() >= log.DebugLevel { - logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + if e.log.Level >= log.DebugLevel { + logEntry := e.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/connlimit: completed ServeHttp on request") } @@ -121,6 +140,7 @@ func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err utils.DefaultHandler.ServeHTTP(w, req, err) } +// ConnLimitOption connection limit option type type ConnLimitOption func(l *ConnLimiter) error // ErrorHandler sets error handler of the server @@ -130,5 +150,3 @@ func ErrorHandler(h utils.ErrorHandler) ConnLimitOption { return nil } } - -var defaultErrHandler = &ConnErrHandler{} diff --git a/vendor/github.com/vulcand/oxy/forward/fwd.go b/vendor/github.com/vulcand/oxy/forward/fwd.go index abeb3c08e..3a715e479 100644 --- a/vendor/github.com/vulcand/oxy/forward/fwd.go +++ b/vendor/github.com/vulcand/oxy/forward/fwd.go @@ -1,12 +1,15 @@ -// package forwarder implements http handler that forwards requests to remote server +// Package forward implements http handler that forwards requests to remote server // and serves back the response // websocket proxying support based on https://github.com/yhat/wsutil package forward import ( + "bytes" "crypto/tls" "errors" "fmt" + "io" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -21,7 +24,7 @@ import ( "github.com/vulcand/oxy/utils" ) -// Oxy Logger interface of the internal +// OxyLogger interface of the internal type OxyLogger interface { log.FieldLogger GetLevel() log.Level @@ -42,8 +45,7 @@ type ReqRewriter interface { type optSetter func(f *Forwarder) error -// PassHostHeader specifies if a client's Host header field should -// be delegated +// PassHostHeader specifies if a client's Host header field should be delegated func PassHostHeader(b bool) optSetter { return func(f *Forwarder) error { f.httpForwarder.passHost = b @@ -68,8 +70,7 @@ func Rewriter(r ReqRewriter) optSetter { } } -// PassHostHeader specifies if a client's Host header field should -// be delegated +// WebsocketTLSClientConfig define the websocker client TLS configuration func WebsocketTLSClientConfig(tcc *tls.Config) optSetter { return func(f *Forwarder) error { f.httpForwarder.tlsClientConfig = tcc @@ -120,6 +121,7 @@ func Logger(l log.FieldLogger) optSetter { } } +// StateListener defines a state listener for the HTTP forwarder func StateListener(stateListener UrlForwardingStateListener) optSetter { return func(f *Forwarder) error { f.stateListener = stateListener @@ -127,6 +129,15 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter { } } +// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed +func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketConnectionClosedHook = hook + return nil + } +} + +// ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { f.httpForwarder.modifyResponse = responseModifier @@ -134,6 +145,7 @@ func ResponseModifier(responseModifier func(*http.Response) error) optSetter { } } +// StreamingFlushInterval defines a streaming flush interval for the HTTP forwarder func StreamingFlushInterval(flushInterval time.Duration) optSetter { return func(f *Forwarder) error { f.httpForwarder.flushInterval = flushInterval @@ -141,11 +153,13 @@ func StreamingFlushInterval(flushInterval time.Duration) optSetter { } } +// ErrorHandlingRoundTripper a error handling round tripper type ErrorHandlingRoundTripper struct { http.RoundTripper errorHandler utils.ErrorHandler } +// RoundTrip executes the round trip func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { res, err := rt.RoundTripper.RoundTrip(req) if err != nil { @@ -185,15 +199,19 @@ type httpForwarder struct { log OxyLogger - bufferPool httputil.BufferPool + bufferPool httputil.BufferPool + websocketConnectionClosedHook func(req *http.Request, conn net.Conn) } +const defaultFlushInterval = time.Duration(100) * time.Millisecond + +// Connection states const ( - defaultFlushInterval = time.Duration(100) * time.Millisecond - StateConnected = iota + StateConnected = iota StateDisconnected ) +// UrlForwardingStateListener URL forwarding state listener type UrlForwardingStateListener func(*url.URL, int) // New creates an instance of Forwarder based on the provided list of configuration options @@ -293,11 +311,6 @@ func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) { outReq.URL.RawQuery = u.RawQuery outReq.RequestURI = "" // Outgoing request should not have RequestURI - // Do not pass client Host header unless optsetter PassHostHeader is set. - if !f.passHost { - outReq.Host = target.Host - } - outReq.Proto = "HTTP/1.1" outReq.ProtoMajor = 1 outReq.ProtoMinor = 1 @@ -305,6 +318,11 @@ func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) { if f.rewriter != nil { f.rewriter.Rewrite(outReq) } + + // Do not pass client Host header unless optsetter PassHostHeader is set. + if !f.passHost { + outReq.Host = target.Host + } } // serveHTTP forwards websocket traffic @@ -368,14 +386,40 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) return } - defer underlyingConn.Close() - defer targetConn.Close() + defer func() { + underlyingConn.Close() + targetConn.Close() + if f.websocketConnectionClosedHook != nil { + f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn()) + } + }() errClient := make(chan error, 1) errBackend := make(chan error, 1) replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { + + forward := func(messageType int, reader io.Reader) error { + writer, err := dst.NextWriter(messageType) + if err != nil { + return err + } + _, err = io.Copy(writer, reader) + if err != nil { + return err + } + return writer.Close() + } + + src.SetPingHandler(func(data string) error { + return forward(websocket.PingMessage, bytes.NewReader([]byte(data))) + }) + + src.SetPongHandler(func(data string) error { + return forward(websocket.PongMessage, bytes.NewReader([]byte(data))) + }) + for { - msgType, msg, err := src.ReadMessage() + msgType, reader, err := src.NextReader() if err != nil { m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) @@ -393,11 +437,11 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } errc <- err if m != nil { - dst.WriteMessage(websocket.CloseMessage, m) + forward(websocket.CloseMessage, bytes.NewReader([]byte(m))) } break } - err = dst.WriteMessage(msgType, msg) + err = forward(msgType, reader) if err != nil { errc <- err break @@ -501,7 +545,7 @@ func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ct } } -// isWebsocketRequest determines if the specified HTTP request is a +// IsWebsocketRequest determines if the specified HTTP request is a // websocket handshake request func IsWebsocketRequest(req *http.Request) bool { containsHeader := func(name, value string) bool { diff --git a/vendor/github.com/vulcand/oxy/forward/headers.go b/vendor/github.com/vulcand/oxy/forward/headers.go index f884a5ee0..512e28435 100644 --- a/vendor/github.com/vulcand/oxy/forward/headers.go +++ b/vendor/github.com/vulcand/oxy/forward/headers.go @@ -1,5 +1,6 @@ package forward +// Headers const ( XForwardedProto = "X-Forwarded-Proto" XForwardedFor = "X-Forwarded-For" @@ -22,7 +23,7 @@ const ( SecWebsocketAccept = "Sec-Websocket-Accept" ) -// Hop-by-hop headers. These are removed when sent to the backend. +// HopHeaders Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // Copied from reverseproxy.go, too bad var HopHeaders = []string{ @@ -36,6 +37,7 @@ var HopHeaders = []string{ Upgrade, } +// WebsocketDialHeaders Websocket dial headers var WebsocketDialHeaders = []string{ Upgrade, Connection, @@ -45,6 +47,7 @@ var WebsocketDialHeaders = []string{ SecWebsocketAccept, } +// WebsocketUpgradeHeaders Websocket upgrade headers var WebsocketUpgradeHeaders = []string{ Upgrade, Connection, @@ -52,6 +55,7 @@ var WebsocketUpgradeHeaders = []string{ SecWebsocketExtensions, } +// XHeaders X-* headers var XHeaders = []string{ XForwardedProto, XForwardedFor, diff --git a/vendor/github.com/vulcand/oxy/forward/rewrite.go b/vendor/github.com/vulcand/oxy/forward/rewrite.go index 38a7f7fc4..60c1a1947 100644 --- a/vendor/github.com/vulcand/oxy/forward/rewrite.go +++ b/vendor/github.com/vulcand/oxy/forward/rewrite.go @@ -8,7 +8,7 @@ import ( "github.com/vulcand/oxy/utils" ) -// Rewriter is responsible for removing hop-by-hop headers and setting forwarding headers +// HeaderRewriter is responsible for removing hop-by-hop headers and setting forwarding headers type HeaderRewriter struct { TrustForwardHeader bool Hostname string @@ -19,6 +19,7 @@ func ipv6fix(clientIP string) string { return strings.Split(clientIP, "%")[0] } +// Rewrite rewrite request headers func (rw *HeaderRewriter) Rewrite(req *http.Request) { if !rw.TrustForwardHeader { utils.RemoveHeaders(req.Header, XHeaders...) @@ -85,6 +86,10 @@ func forwardedPort(req *http.Request) string { return port } + if req.Header.Get(XForwardedProto) == "https" || req.Header.Get(XForwardedProto) == "wss" { + return "443" + } + if req.TLS != nil { return "443" } diff --git a/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go b/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go index 5aeb13ae3..1f8dfe95d 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go @@ -6,7 +6,7 @@ import ( "time" ) -// SplitRatios provides simple anomaly detection for requests latencies. +// SplitLatencies provides simple anomaly detection for requests latencies. // it splits values into good or bad category based on the threshold and the median value. // If all values are not far from the median, it will return all values in 'good' set. // Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored. @@ -23,10 +23,10 @@ func SplitLatencies(values []time.Duration, precision time.Duration) (good map[t good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool) // Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise. vgood, vbad := SplitFloat64(2, 0, ratios) - for r, _ := range vgood { + for r := range vgood { good[v2r[r]] = true } - for r, _ := range vbad { + for r := range vbad { bad[v2r[r]] = true } return good, bad diff --git a/vendor/github.com/vulcand/oxy/memmetrics/counter.go b/vendor/github.com/vulcand/oxy/memmetrics/counter.go index 361d8a878..4faf905dd 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/counter.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/counter.go @@ -9,6 +9,7 @@ import ( type rcOptSetter func(*RollingCounter) error +// CounterClock defines a counter clock func CounterClock(c timetools.TimeProvider) rcOptSetter { return func(r *RollingCounter) error { r.clock = c @@ -16,7 +17,7 @@ func CounterClock(c timetools.TimeProvider) rcOptSetter { } } -// Calculates in memory failure rate of an endpoint using rolling window of a predefined size +// RollingCounter Calculates in memory failure rate of an endpoint using rolling window of a predefined size type RollingCounter struct { clock timetools.TimeProvider resolution time.Duration @@ -57,11 +58,13 @@ func NewCounter(buckets int, resolution time.Duration, options ...rcOptSetter) ( return rc, nil } +// Append append a counter func (c *RollingCounter) Append(o *RollingCounter) error { c.Inc(int(o.Count())) return nil } +// Clone clone a counter func (c *RollingCounter) Clone() *RollingCounter { c.cleanup() other := &RollingCounter{ @@ -75,6 +78,7 @@ func (c *RollingCounter) Clone() *RollingCounter { return other } +// Reset reset a counter func (c *RollingCounter) Reset() { c.lastBucket = -1 c.countedBuckets = 0 @@ -84,27 +88,33 @@ func (c *RollingCounter) Reset() { } } +// CountedBuckets gets counted buckets func (c *RollingCounter) CountedBuckets() int { return c.countedBuckets } +// Count counts func (c *RollingCounter) Count() int64 { c.cleanup() return c.sum() } +// Resolution gets resolution func (c *RollingCounter) Resolution() time.Duration { return c.resolution } +// Buckets gets buckets func (c *RollingCounter) Buckets() int { return len(c.values) } +// WindowSize gets windows size func (c *RollingCounter) WindowSize() time.Duration { return time.Duration(len(c.values)) * c.resolution } +// Inc increment counter func (c *RollingCounter) Inc(v int) { c.cleanup() c.incBucketValue(v) diff --git a/vendor/github.com/vulcand/oxy/memmetrics/histogram.go b/vendor/github.com/vulcand/oxy/memmetrics/histogram.go index 02c1d561e..2c3aa76af 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/histogram.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/histogram.go @@ -20,6 +20,7 @@ type HDRHistogram struct { h *hdrhistogram.Histogram } +// NewHDRHistogram creates a new HDRHistogram func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) { defer func() { if msg := recover(); msg != nil { @@ -34,37 +35,42 @@ func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) }, nil } -func (r *HDRHistogram) Export() *HDRHistogram { - var hist *hdrhistogram.Histogram = nil - if r.h != nil { - snapshot := r.h.Export() +// Export export a HDRHistogram +func (h *HDRHistogram) Export() *HDRHistogram { + var hist *hdrhistogram.Histogram + if h.h != nil { + snapshot := h.h.Export() hist = hdrhistogram.Import(snapshot) } - return &HDRHistogram{low: r.low, high: r.high, sigfigs: r.sigfigs, h: hist} + return &HDRHistogram{low: h.low, high: h.high, sigfigs: h.sigfigs, h: hist} } -// Returns latency at quantile with microsecond precision +// LatencyAtQuantile sets latency at quantile with microsecond precision func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond } -// Records latencies with microsecond precision +// RecordLatencies Records latencies with microsecond precision func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error { return h.RecordValues(int64(d/time.Microsecond), n) } +// Reset reset a HDRHistogram func (h *HDRHistogram) Reset() { h.h.Reset() } +// ValueAtQuantile sets value at quantile func (h *HDRHistogram) ValueAtQuantile(q float64) int64 { return h.h.ValueAtQuantile(q) } +// RecordValues sets record values func (h *HDRHistogram) RecordValues(v, n int64) error { return h.h.RecordValues(v, n) } +// Merge merge a HDRHistogram func (h *HDRHistogram) Merge(other *HDRHistogram) error { if other == nil { return fmt.Errorf("other is nil") @@ -75,6 +81,7 @@ func (h *HDRHistogram) Merge(other *HDRHistogram) error { type rhOptSetter func(r *RollingHDRHistogram) error +// RollingClock sets a clock func RollingClock(clock timetools.TimeProvider) rhOptSetter { return func(r *RollingHDRHistogram) error { r.clock = clock @@ -82,7 +89,7 @@ func RollingClock(clock timetools.TimeProvider) rhOptSetter { } } -// RollingHistogram holds multiple histograms and rotates every period. +// RollingHDRHistogram holds multiple histograms and rotates every period. // It provides resulting histogram as a result of a call of 'Merged' function. type RollingHDRHistogram struct { idx int @@ -96,6 +103,7 @@ type RollingHDRHistogram struct { clock timetools.TimeProvider } +// NewRollingHDRHistogram created a new RollingHDRHistogram func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOptSetter) (*RollingHDRHistogram, error) { rh := &RollingHDRHistogram{ bucketCount: bucketCount, @@ -127,6 +135,7 @@ func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, return rh, nil } +// Export export a RollingHDRHistogram func (r *RollingHDRHistogram) Export() *RollingHDRHistogram { export := &RollingHDRHistogram{} export.idx = r.idx @@ -147,6 +156,7 @@ func (r *RollingHDRHistogram) Export() *RollingHDRHistogram { return export } +// Append append a RollingHDRHistogram func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs { return fmt.Errorf("can't merge") @@ -160,6 +170,7 @@ func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { return nil } +// Reset reset a RollingHDRHistogram func (r *RollingHDRHistogram) Reset() { r.idx = 0 r.lastRoll = r.clock.UtcNow() @@ -173,6 +184,7 @@ func (r *RollingHDRHistogram) rotate() { r.buckets[r.idx].Reset() } +// Merged gets merged histogram func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) { m, err := NewHDRHistogram(r.low, r.high, r.sigfigs) if err != nil { @@ -194,10 +206,12 @@ func (r *RollingHDRHistogram) getHist() *HDRHistogram { return r.buckets[r.idx] } +// RecordLatencies sets records latencies func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error { return r.getHist().RecordLatencies(v, n) } +// RecordValues set record values func (r *RollingHDRHistogram) RecordValues(v, n int64) error { return r.getHist().RecordValues(v, n) } diff --git a/vendor/github.com/vulcand/oxy/memmetrics/ratio.go b/vendor/github.com/vulcand/oxy/memmetrics/ratio.go index f21f375ea..ecfd50371 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/ratio.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/ratio.go @@ -8,6 +8,7 @@ import ( type ratioOptSetter func(r *RatioCounter) error +// RatioClock sets a clock func RatioClock(clock timetools.TimeProvider) ratioOptSetter { return func(r *RatioCounter) error { r.clock = clock @@ -22,6 +23,7 @@ type RatioCounter struct { b *RollingCounter } +// NewRatioCounter creates a new RatioCounter func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptSetter) (*RatioCounter, error) { rc := &RatioCounter{} @@ -50,39 +52,48 @@ func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptS return rc, nil } +// Reset reset the counter func (r *RatioCounter) Reset() { r.a.Reset() r.b.Reset() } +// IsReady returns true if the counter is ready func (r *RatioCounter) IsReady() bool { return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values) } +// CountA gets count A func (r *RatioCounter) CountA() int64 { return r.a.Count() } +// CountB gets count B func (r *RatioCounter) CountB() int64 { return r.b.Count() } +// Resolution gets resolution func (r *RatioCounter) Resolution() time.Duration { return r.a.Resolution() } +// Buckets gets buckets func (r *RatioCounter) Buckets() int { return r.a.Buckets() } +// WindowSize gets windows size func (r *RatioCounter) WindowSize() time.Duration { return r.a.WindowSize() } +// ProcessedCount gets processed count func (r *RatioCounter) ProcessedCount() int64 { return r.CountA() + r.CountB() } +// Ratio gets ratio func (r *RatioCounter) Ratio() float64 { a := r.a.Count() b := r.b.Count() @@ -93,28 +104,34 @@ func (r *RatioCounter) Ratio() float64 { return float64(a) / float64(a+b) } +// IncA increment counter A func (r *RatioCounter) IncA(v int) { r.a.Inc(v) } +// IncB increment counter B func (r *RatioCounter) IncB(v int) { r.b.Inc(v) } +// TestMeter a test meter type TestMeter struct { Rate float64 NotReady bool WindowSize time.Duration } +// GetWindowSize gets windows size func (tm *TestMeter) GetWindowSize() time.Duration { return tm.WindowSize } +// IsReady returns true if the meter is ready func (tm *TestMeter) IsReady() bool { return !tm.NotReady } +// GetRate gets rate func (tm *TestMeter) GetRate() float64 { return tm.Rate } diff --git a/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go b/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go index 4bdb4bba2..34b396915 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go @@ -29,10 +29,16 @@ type RTMetrics struct { type rrOptSetter func(r *RTMetrics) error +// NewRTMetricsFn builder function type type NewRTMetricsFn func() (*RTMetrics, error) + +// NewCounterFn builder function type type NewCounterFn func() (*RollingCounter, error) + +// NewRollingHistogramFn builder function type type NewRollingHistogramFn func() (*RollingHDRHistogram, error) +// RTCounter set a builder function for Counter func RTCounter(new NewCounterFn) rrOptSetter { return func(r *RTMetrics) error { r.newCounter = new @@ -40,13 +46,15 @@ func RTCounter(new NewCounterFn) rrOptSetter { } } -func RTHistogram(new NewRollingHistogramFn) rrOptSetter { +// RTHistogram set a builder function for RollingHistogram +func RTHistogram(fn NewRollingHistogramFn) rrOptSetter { return func(r *RTMetrics) error { - r.newHist = new + r.newHist = fn return nil } } +// RTClock sets a clock func RTClock(clock timetools.TimeProvider) rrOptSetter { return func(r *RTMetrics) error { r.clock = clock @@ -103,7 +111,7 @@ func NewRTMetrics(settings ...rrOptSetter) (*RTMetrics, error) { return m, nil } -// Returns a new RTMetrics which is a copy of the current one +// Export Returns a new RTMetrics which is a copy of the current one func (m *RTMetrics) Export() *RTMetrics { m.statusCodesLock.RLock() defer m.statusCodesLock.RUnlock() @@ -130,11 +138,12 @@ func (m *RTMetrics) Export() *RTMetrics { return export } +// CounterWindowSize gets total windows size func (m *RTMetrics) CounterWindowSize() time.Duration { return m.total.WindowSize() } -// GetNetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection +// NetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection // that occurred in the given time window compared to the total requests count. func (m *RTMetrics) NetworkErrorRatio() float64 { if m.total.Count() == 0 { @@ -143,7 +152,7 @@ func (m *RTMetrics) NetworkErrorRatio() float64 { return float64(m.netErrors.Count()) / float64(m.total.Count()) } -// GetResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB) +// ResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB) func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 { a := int64(0) b := int64(0) @@ -163,6 +172,7 @@ func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 { return 0 } +// Append append a metric func (m *RTMetrics) Append(other *RTMetrics) error { if m == other { return errors.New("RTMetrics cannot append to self") @@ -196,6 +206,7 @@ func (m *RTMetrics) Append(other *RTMetrics) error { return m.histogram.Append(copied.histogram) } +// Record records a metric func (m *RTMetrics) Record(code int, duration time.Duration) { m.total.Inc(1) if code == http.StatusGatewayTimeout || code == http.StatusBadGateway { @@ -205,17 +216,17 @@ func (m *RTMetrics) Record(code int, duration time.Duration) { m.recordLatency(duration) } -// GetTotalCount returns total count of processed requests collected. +// TotalCount returns total count of processed requests collected. func (m *RTMetrics) TotalCount() int64 { return m.total.Count() } -// GetNetworkErrorCount returns total count of processed requests observed +// NetworkErrorCount returns total count of processed requests observed func (m *RTMetrics) NetworkErrorCount() int64 { return m.netErrors.Count() } -// GetStatusCodesCounts returns map with counts of the response codes +// StatusCodesCounts returns map with counts of the response codes func (m *RTMetrics) StatusCodesCounts() map[int]int64 { sc := make(map[int]int64) m.statusCodesLock.RLock() @@ -228,13 +239,14 @@ func (m *RTMetrics) StatusCodesCounts() map[int]int64 { return sc } -// GetLatencyHistogram computes and returns resulting histogram with latencies observed. +// LatencyHistogram computes and returns resulting histogram with latencies observed. func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) { m.histogramLock.Lock() defer m.histogramLock.Unlock() return m.histogram.Merged() } +// Reset reset metrics func (m *RTMetrics) Reset() { m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() @@ -284,7 +296,7 @@ const ( counterResolution = time.Second histMin = 1 histMax = 3600000000 // 1 hour in microseconds - histSignificantFigures = 2 // signigicant figures (1% precision) + histSignificantFigures = 2 // significant figures (1% precision) histBuckets = 6 // number of sub-histograms in a rolling histogram histPeriod = 10 * time.Second // roll time ) diff --git a/vendor/github.com/vulcand/oxy/ratelimit/bucket.go b/vendor/github.com/vulcand/oxy/ratelimit/bucket.go index 78507faf9..9134d1828 100644 --- a/vendor/github.com/vulcand/oxy/ratelimit/bucket.go +++ b/vendor/github.com/vulcand/oxy/ratelimit/bucket.go @@ -7,6 +7,7 @@ import ( "github.com/mailgun/timetools" ) +// UndefinedDelay default delay const UndefinedDelay = -1 // rate defines token bucket parameters. @@ -20,7 +21,7 @@ func (r *rate) String() string { return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst) } -// Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket) +// tokenBucket Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket) type tokenBucket struct { // The time period controlled by the bucket in nanoseconds. period time.Duration @@ -63,7 +64,7 @@ func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) { tb.updateAvailableTokens() tb.lastConsumed = 0 if tokens > tb.burst { - return UndefinedDelay, fmt.Errorf("Requested tokens larger than max tokens") + return UndefinedDelay, fmt.Errorf("requested tokens larger than max tokens") } if tb.availableTokens < tokens { return tb.timeTillAvailable(tokens), nil @@ -83,11 +84,11 @@ func (tb *tokenBucket) rollback() { tb.lastConsumed = 0 } -// Update modifies `average` and `burst` fields of the token bucket according +// update modifies `average` and `burst` fields of the token bucket according // to the provided `Rate` func (tb *tokenBucket) update(rate *rate) error { if rate.period != tb.period { - return fmt.Errorf("Period mismatch: %v != %v", tb.period, rate.period) + return fmt.Errorf("period mismatch: %v != %v", tb.period, rate.period) } tb.timePerToken = time.Duration(int64(tb.period) / rate.average) tb.burst = rate.burst diff --git a/vendor/github.com/vulcand/oxy/ratelimit/bucketset.go b/vendor/github.com/vulcand/oxy/ratelimit/bucketset.go index f4a246568..af2c8bb1c 100644 --- a/vendor/github.com/vulcand/oxy/ratelimit/bucketset.go +++ b/vendor/github.com/vulcand/oxy/ratelimit/bucketset.go @@ -2,11 +2,11 @@ package ratelimit import ( "fmt" + "sort" "strings" "time" "github.com/mailgun/timetools" - "sort" ) // TokenBucketSet represents a set of TokenBucket covering different time periods. @@ -16,7 +16,7 @@ type TokenBucketSet struct { clock timetools.TimeProvider } -// newTokenBucketSet creates a `TokenBucketSet` from the specified `rates`. +// NewTokenBucketSet creates a `TokenBucketSet` from the specified `rates`. func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet { tbs := new(TokenBucketSet) tbs.clock = clock @@ -54,9 +54,10 @@ func (tbs *TokenBucketSet) Update(rates *RateSet) { } } +// Consume consume tokens func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) { var maxDelay time.Duration = UndefinedDelay - var firstErr error = nil + var firstErr error for _, tokenBucket := range tbs.buckets { // We keep calling `Consume` even after a error is returned for one of // buckets because that allows us to simplify the rollback procedure, @@ -80,6 +81,7 @@ func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) { return maxDelay, firstErr } +// GetMaxPeriod returns the max period func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration { return tbs.maxPeriod } diff --git a/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go b/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go index be16c782b..bfd4c3b2e 100644 --- a/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go +++ b/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go @@ -1,4 +1,4 @@ -// Tokenbucket based request rate limiter +// Package ratelimit Tokenbucket based request rate limiter package ratelimit import ( @@ -13,6 +13,7 @@ import ( "github.com/vulcand/oxy/utils" ) +// DefaultCapacity default capacity const DefaultCapacity = 65536 // RateSet maintains a set of rates. It can contain only one rate per period at a time. @@ -31,15 +32,15 @@ func NewRateSet() *RateSet { // set then the new rate overrides the old one. func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error { if period <= 0 { - return fmt.Errorf("Invalid period: %v", period) + return fmt.Errorf("invalid period: %v", period) } if average <= 0 { - return fmt.Errorf("Invalid average: %v", average) + return fmt.Errorf("invalid average: %v", average) } if burst <= 0 { - return fmt.Errorf("Invalid burst: %v", burst) + return fmt.Errorf("invalid burst: %v", burst) } - rs.m[period] = &rate{period, average, burst} + rs.m[period] = &rate{period: period, average: average, burst: burst} return nil } @@ -47,12 +48,15 @@ func (rs *RateSet) String() string { return fmt.Sprint(rs.m) } +// RateExtractor rate extractor type RateExtractor interface { Extract(r *http.Request) (*RateSet, error) } +// RateExtractorFunc rate extractor function type type RateExtractorFunc func(r *http.Request) (*RateSet, error) +// Extract extract from request func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) { return e(r) } @@ -68,20 +72,24 @@ type TokenLimiter struct { errHandler utils.ErrorHandler capacity int next http.Handler + + log *log.Logger } // New constructs a `TokenLimiter` middleware instance. func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) { if defaultRates == nil || len(defaultRates.m) == 0 { - return nil, fmt.Errorf("Provide default rates") + return nil, fmt.Errorf("provide default rates") } if extract == nil { - return nil, fmt.Errorf("Provide extract function") + return nil, fmt.Errorf("provide extract function") } tl := &TokenLimiter{ next: next, defaultRates: defaultRates, extract: extract, + + log: log.StandardLogger(), } for _, o := range opts { @@ -98,6 +106,17 @@ func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet return tl, nil } +// Logger defines the logger the token limiter will use. +// +// It defaults to logrus.StandardLogger(), the global logger used by logrus. +func Logger(l *log.Logger) TokenLimiterOption { + return func(tl *TokenLimiter) error { + tl.log = l + return nil + } +} + +// Wrap sets the next handler to be called by token limiter handler. func (tl *TokenLimiter) Wrap(next http.Handler) { tl.next = next } @@ -110,7 +129,7 @@ func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if err := tl.consumeRates(req, source, amount); err != nil { - log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err) + tl.log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err) tl.errHandler.ServeHTTP(w, req, err) return } @@ -155,7 +174,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet { rates, err := tl.extractRates.Extract(req) if err != nil { - log.Errorf("Failed to retrieve rates: %v", err) + tl.log.Errorf("Failed to retrieve rates: %v", err) return tl.defaultRates } @@ -167,6 +186,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet { return rates } +// MaxRateError max rate error type MaxRateError struct { delay time.Duration } @@ -175,19 +195,21 @@ func (m *MaxRateError) Error() string { return fmt.Sprintf("max rate reached: retry-in %v", m.delay) } -type RateErrHandler struct { -} +// RateErrHandler error handler +type RateErrHandler struct{} func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { if rerr, ok := err.(*MaxRateError); ok { + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", rerr.delay.Seconds())) w.Header().Set("X-Retry-In", rerr.delay.String()) - w.WriteHeader(429) + w.WriteHeader(http.StatusTooManyRequests) w.Write([]byte(err.Error())) return } utils.DefaultHandler.ServeHTTP(w, req, err) } +// TokenLimiterOption token limiter option type type TokenLimiterOption func(l *TokenLimiter) error // ErrorHandler sets error handler of the server @@ -198,6 +220,7 @@ func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption { } } +// ExtractRates sets the rate extractor func ExtractRates(e RateExtractor) TokenLimiterOption { return func(cl *TokenLimiter) error { cl.extractRates = e @@ -205,6 +228,7 @@ func ExtractRates(e RateExtractor) TokenLimiterOption { } } +// Clock sets the clock func Clock(clock timetools.TimeProvider) TokenLimiterOption { return func(cl *TokenLimiter) error { cl.clock = clock @@ -212,6 +236,7 @@ func Clock(clock timetools.TimeProvider) TokenLimiterOption { } } +// Capacity sets the capacity func Capacity(cap int) TokenLimiterOption { return func(cl *TokenLimiter) error { if cap <= 0 { diff --git a/vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go b/vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go index 418f4988c..02ae4548e 100644 --- a/vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go +++ b/vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go @@ -2,4 +2,5 @@ package roundrobin import "net/http" +// RequestRewriteListener function to rewrite request type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request) diff --git a/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go b/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go index fec74d26b..1d182d895 100644 --- a/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go +++ b/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go @@ -16,13 +16,14 @@ import ( // RebalancerOption - functional option setter for rebalancer type RebalancerOption func(*Rebalancer) error -// Meter measures server peformance and returns it's relative value via rating +// Meter measures server performance and returns it's relative value via rating type Meter interface { Rating() float64 Record(int, time.Duration) IsReady() bool } +// NewMeterFn type of functions to create new Meter type NewMeterFn func() (Meter, error) // Rebalancer increases weights on servers that perform better than others. It also rolls back to original weights @@ -52,8 +53,11 @@ type Rebalancer struct { stickySession *StickySession requestRewriteListener RequestRewriteListener + + log *log.Logger } +// RebalancerClock sets a clock func RebalancerClock(clock timetools.TimeProvider) RebalancerOption { return func(r *Rebalancer) error { r.clock = clock @@ -61,6 +65,7 @@ func RebalancerClock(clock timetools.TimeProvider) RebalancerOption { } } +// RebalancerBackoff sets a beck off duration func RebalancerBackoff(d time.Duration) RebalancerOption { return func(r *Rebalancer) error { r.backoffDuration = d @@ -68,6 +73,7 @@ func RebalancerBackoff(d time.Duration) RebalancerOption { } } +// RebalancerMeter sets a Meter builder function func RebalancerMeter(newMeter NewMeterFn) RebalancerOption { return func(r *Rebalancer) error { r.newMeter = newMeter @@ -83,6 +89,7 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption { } } +// RebalancerStickySession sets a sticky session func RebalancerStickySession(stickySession *StickySession) RebalancerOption { return func(r *Rebalancer) error { r.stickySession = stickySession @@ -90,7 +97,7 @@ func RebalancerStickySession(stickySession *StickySession) RebalancerOption { } } -// RebalancerErrorHandler is a functional argument that sets error handler of the server +// RebalancerRequestRewriteListener is a functional argument that sets error handler of the server func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption { return func(r *Rebalancer) error { r.requestRewriteListener = rrl @@ -98,11 +105,14 @@ func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOpti } } +// NewRebalancer creates a new Rebalancer func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) { rb := &Rebalancer{ mtx: &sync.Mutex{}, next: handler, stickySession: nil, + + log: log.StandardLogger(), } for _, o := range opts { if err := o(rb); err != nil { @@ -134,6 +144,17 @@ func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalanc return rb, nil } +// RebalancerLogger defines the logger the rebalancer will use. +// +// It defaults to logrus.StandardLogger(), the global logger used by logrus. +func RebalancerLogger(l *log.Logger) RebalancerOption { + return func(rb *Rebalancer) error { + rb.log = l + return nil + } +} + +// Servers gets all servers func (rb *Rebalancer) Servers() []*url.URL { rb.mtx.Lock() defer rb.mtx.Unlock() @@ -142,8 +163,8 @@ func (rb *Rebalancer) Servers() []*url.URL { } func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if log.GetLevel() >= log.DebugLevel { - logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + if rb.log.Level >= log.DebugLevel { + logEntry := rb.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: completed ServeHttp on request") } @@ -169,25 +190,25 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if !stuck { - url, err := rb.next.NextServer() + fwdURL, err := rb.next.NextServer() if err != nil { rb.errHandler.ServeHTTP(w, req, err) return } if log.GetLevel() >= log.DebugLevel { - //log which backend URL we're sending this request to - log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL") + // log which backend URL we're sending this request to + log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": fwdURL}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL") } if rb.stickySession != nil { - rb.stickySession.StickBackend(url, &w) + rb.stickySession.StickBackend(fwdURL, &w) } - newReq.URL = url + newReq.URL = fwdURL } - //Emit event to a listener if one exists + // Emit event to a listener if one exists if rb.requestRewriteListener != nil { rb.requestRewriteListener(req, &newReq) } @@ -215,6 +236,7 @@ func (rb *Rebalancer) reset() { rb.ratings = make([]float64, len(rb.servers)) } +// Wrap sets the next handler to be called by rebalancer handler. func (rb *Rebalancer) Wrap(next balancerHandler) error { if rb.next != nil { return fmt.Errorf("already bound to %T", rb.next) @@ -223,6 +245,7 @@ func (rb *Rebalancer) Wrap(next balancerHandler) error { return nil } +// UpsertServer upsert a server func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error { rb.mtx.Lock() defer rb.mtx.Unlock() @@ -239,6 +262,7 @@ func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error { return nil } +// RemoveServer remove a server func (rb *Rebalancer) RemoveServer(u *url.URL) error { rb.mtx.Lock() defer rb.mtx.Unlock() @@ -289,7 +313,7 @@ func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) { return nil, -1 } -// Called on every load balancer ServeHTTP call, returns the suggested weights +// adjustWeights Called on every load balancer ServeHTTP call, returns the suggested weights // on every call, can adjust weights if needed. func (rb *Rebalancer) adjustWeights() { rb.mtx.Lock() @@ -319,7 +343,7 @@ func (rb *Rebalancer) adjustWeights() { func (rb *Rebalancer) applyWeights() { for _, srv := range rb.servers { - log.Debugf("upsert server %v, weight %v", srv.url, srv.curWeight) + rb.log.Debugf("upsert server %v, weight %v", srv.url, srv.curWeight) rb.next.UpsertServer(srv.url, Weight(srv.curWeight)) } } @@ -331,7 +355,7 @@ func (rb *Rebalancer) setMarkedWeights() bool { if srv.good { weight := increase(srv.curWeight) if weight <= FSMMaxWeight { - log.Debugf("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight) + rb.log.Debugf("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight) srv.curWeight = weight changed = true } @@ -378,7 +402,7 @@ func (rb *Rebalancer) markServers() bool { } } if len(g) != 0 && len(b) != 0 { - log.Debugf("bad: %v good: %v, ratings: %v", b, g, rb.ratings) + rb.log.Debugf("bad: %v good: %v, ratings: %v", b, g, rb.ratings) } return len(g) != 0 && len(b) != 0 } @@ -433,9 +457,8 @@ func decrease(target, current int) int { adjusted := current / FSMGrowFactor if adjusted < target { return target - } else { - return adjusted } + return adjusted } // rebalancer server record that keeps track of the original weight supplied by user @@ -448,9 +471,9 @@ type rbServer struct { } const ( - // This is the maximum weight that handler will set for the server + // FSMMaxWeight is the maximum weight that handler will set for the server FSMMaxWeight = 4096 - // Multiplier for the server weight + // FSMGrowFactor Multiplier for the server weight FSMGrowFactor = 4 ) @@ -460,10 +483,12 @@ type codeMeter struct { codeE int } +// Rating gets ratio func (n *codeMeter) Rating() float64 { return n.r.Ratio() } +// Record records a meter func (n *codeMeter) Record(code int, d time.Duration) { if code >= n.codeS && code < n.codeE { n.r.IncA(1) @@ -472,6 +497,7 @@ func (n *codeMeter) Record(code int, d time.Duration) { } } +// IsReady returns true if the counter is ready func (n *codeMeter) IsReady() bool { return n.r.IsReady() } diff --git a/vendor/github.com/vulcand/oxy/roundrobin/rr.go b/vendor/github.com/vulcand/oxy/roundrobin/rr.go index 053773b7d..631a97af8 100644 --- a/vendor/github.com/vulcand/oxy/roundrobin/rr.go +++ b/vendor/github.com/vulcand/oxy/roundrobin/rr.go @@ -1,4 +1,4 @@ -// package roundrobin implements dynamic weighted round robin load balancer http handler +// Package roundrobin implements dynamic weighted round robin load balancer http handler package roundrobin import ( @@ -30,6 +30,7 @@ func ErrorHandler(h utils.ErrorHandler) LBOption { } } +// EnableStickySession enable sticky session func EnableStickySession(stickySession *StickySession) LBOption { return func(s *RoundRobin) error { s.stickySession = stickySession @@ -37,7 +38,7 @@ func EnableStickySession(stickySession *StickySession) LBOption { } } -// ErrorHandler is a functional argument that sets error handler of the server +// RoundRobinRequestRewriteListener is a functional argument that sets error handler of the server func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { return func(s *RoundRobin) error { s.requestRewriteListener = rrl @@ -45,6 +46,7 @@ func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { } } +// RoundRobin implements dynamic weighted round robin load balancer http handler type RoundRobin struct { mutex *sync.Mutex next http.Handler @@ -55,8 +57,11 @@ type RoundRobin struct { currentWeight int stickySession *StickySession requestRewriteListener RequestRewriteListener + + log *log.Logger } +// New created a new RoundRobin func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { rr := &RoundRobin{ next: next, @@ -64,6 +69,8 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { mutex: &sync.Mutex{}, servers: []*server{}, stickySession: nil, + + log: log.StandardLogger(), } for _, o := range opts { if err := o(rr); err != nil { @@ -76,13 +83,24 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { return rr, nil } +// RoundRobinLogger defines the logger the round robin load balancer will use. +// +// It defaults to logrus.StandardLogger(), the global logger used by logrus. +func RoundRobinLogger(l *log.Logger) LBOption { + return func(r *RoundRobin) error { + r.log = l + return nil + } +} + +// Next returns the next handler func (r *RoundRobin) Next() http.Handler { return r.next } func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if log.GetLevel() >= log.DebugLevel { - logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + if r.log.Level >= log.DebugLevel { + logEntry := r.log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/roundrobin/rr: completed ServeHttp on request") } @@ -116,12 +134,12 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { newReq.URL = url } - if log.GetLevel() >= log.DebugLevel { - //log which backend URL we're sending this request to - log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL") + if r.log.Level >= log.DebugLevel { + // log which backend URL we're sending this request to + r.log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL") } - //Emit event to a listener if one exists + // Emit event to a listener if one exists if r.requestRewriteListener != nil { r.requestRewriteListener(req, &newReq) } @@ -129,6 +147,7 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.next.ServeHTTP(w, &newReq) } +// NextServer gets the next server func (r *RoundRobin) NextServer() (*url.URL, error) { srv, err := r.nextServer() if err != nil { @@ -172,6 +191,7 @@ func (r *RoundRobin) nextServer() (*server, error) { } } +// RemoveServer remove a server func (r *RoundRobin) RemoveServer(u *url.URL) error { r.mutex.Lock() defer r.mutex.Unlock() @@ -185,6 +205,7 @@ func (r *RoundRobin) RemoveServer(u *url.URL) error { return nil } +// Servers gets servers URL func (r *RoundRobin) Servers() []*url.URL { r.mutex.Lock() defer r.mutex.Unlock() @@ -196,6 +217,7 @@ func (r *RoundRobin) Servers() []*url.URL { return out } +// ServerWeight gets the server weight func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) { r.mutex.Lock() defer r.mutex.Unlock() @@ -206,7 +228,7 @@ func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) { return -1, false } -// In case if server is already present in the load balancer, returns error +// UpsertServer In case if server is already present in the load balancer, returns error func (r *RoundRobin) UpsertServer(u *url.URL, options ...ServerOption) error { r.mutex.Lock() defer r.mutex.Unlock() @@ -306,6 +328,7 @@ type server struct { var defaultWeight = 1 +// SetDefaultWeight sets the default server weight func SetDefaultWeight(weight int) error { if weight < 0 { return fmt.Errorf("default weight should be >= 0") diff --git a/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go b/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go index 3fabeb975..123fbdfad 100644 --- a/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go +++ b/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go @@ -1,4 +1,3 @@ -// package stickysession is a mixin for load balancers that implements layer 7 (http cookie) session affinity package roundrobin import ( @@ -6,12 +5,14 @@ import ( "net/url" ) +// StickySession is a mixin for load balancers that implements layer 7 (http cookie) session affinity type StickySession struct { cookieName string } +// NewStickySession creates a new StickySession func NewStickySession(cookieName string) *StickySession { - return &StickySession{cookieName} + return &StickySession{cookieName: cookieName} } // GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. @@ -32,11 +33,11 @@ func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url. if s.isBackendAlive(serverURL, servers) { return serverURL, true, nil - } else { - return nil, false, nil } + return nil, false, nil } +// StickBackend creates and sets the cookie func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) { cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"} http.SetCookie(*w, cookie) diff --git a/vendor/github.com/vulcand/oxy/utils/auth.go b/vendor/github.com/vulcand/oxy/utils/auth.go index b80b91685..4fd819cfe 100644 --- a/vendor/github.com/vulcand/oxy/utils/auth.go +++ b/vendor/github.com/vulcand/oxy/utils/auth.go @@ -6,6 +6,7 @@ import ( "strings" ) +// BasicAuth basic auth information type BasicAuth struct { Username string Password string @@ -16,6 +17,7 @@ func (ba *BasicAuth) String() string { return fmt.Sprintf("Basic %s", encoded) } +// ParseAuthHeader creates a new BasicAuth from header values func ParseAuthHeader(header string) (*BasicAuth, error) { values := strings.Fields(header) if len(values) != 2 { diff --git a/vendor/github.com/vulcand/oxy/utils/dumpreq.go b/vendor/github.com/vulcand/oxy/utils/dumpreq.go index ef34d38f6..eecb2220c 100644 --- a/vendor/github.com/vulcand/oxy/utils/dumpreq.go +++ b/vendor/github.com/vulcand/oxy/utils/dumpreq.go @@ -9,6 +9,7 @@ import ( "net/url" ) +// SerializableHttpRequest serializable HTTP request type SerializableHttpRequest struct { Method string URL *url.URL @@ -28,6 +29,7 @@ type SerializableHttpRequest struct { TLS *tls.ConnectionState } +// Clone clone a request func Clone(r *http.Request) *SerializableHttpRequest { if r == nil { return nil @@ -47,14 +49,16 @@ func Clone(r *http.Request) *SerializableHttpRequest { return rc } +// ToJson serializes to JSON func (s *SerializableHttpRequest) ToJson() string { - if jsonVal, err := json.Marshal(s); err != nil || jsonVal == nil { - return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err.Error()) - } else { - return string(jsonVal) + jsonVal, err := json.Marshal(s) + if err != nil || jsonVal == nil { + return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err) } + return string(jsonVal) } +// DumpHttpRequest dump a HTTP request to JSON func DumpHttpRequest(req *http.Request) string { - return fmt.Sprintf("%v", Clone(req).ToJson()) + return Clone(req).ToJson() } diff --git a/vendor/github.com/vulcand/oxy/utils/handler.go b/vendor/github.com/vulcand/oxy/utils/handler.go index 003fc0319..24b9e3a88 100644 --- a/vendor/github.com/vulcand/oxy/utils/handler.go +++ b/vendor/github.com/vulcand/oxy/utils/handler.go @@ -1,22 +1,34 @@ package utils import ( + "context" "io" "net" "net/http" + + log "github.com/sirupsen/logrus" ) +// StatusClientClosedRequest non-standard HTTP status code for client disconnection +const StatusClientClosedRequest = 499 + +// StatusClientClosedRequestText non-standard HTTP status for client disconnection +const StatusClientClosedRequestText = "Client Closed Request" + +// ErrorHandler error handler type ErrorHandler interface { ServeHTTP(w http.ResponseWriter, req *http.Request, err error) } +// DefaultHandler default error handler var DefaultHandler ErrorHandler = &StdHandler{} -type StdHandler struct { -} +// StdHandler Standard error handler +type StdHandler struct{} func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { statusCode := http.StatusInternalServerError + if e, ok := err.(net.Error); ok { if e.Timeout() { statusCode = http.StatusGatewayTimeout @@ -25,11 +37,23 @@ func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err err } } else if err == io.EOF { statusCode = http.StatusBadGateway + } else if err == context.Canceled { + statusCode = StatusClientClosedRequest } + w.WriteHeader(statusCode) - w.Write([]byte(http.StatusText(statusCode))) + w.Write([]byte(statusText(statusCode))) + log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err) } +func statusText(statusCode int) string { + if statusCode == StatusClientClosedRequest { + return StatusClientClosedRequestText + } + return http.StatusText(statusCode) +} + +// ErrorHandlerFunc error handler function type type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error) // ServeHTTP calls f(w, r). diff --git a/vendor/github.com/vulcand/oxy/utils/netutils.go b/vendor/github.com/vulcand/oxy/utils/netutils.go index 95c30e7e5..692d30038 100644 --- a/vendor/github.com/vulcand/oxy/utils/netutils.go +++ b/vendor/github.com/vulcand/oxy/utils/netutils.go @@ -12,18 +12,29 @@ import ( log "github.com/sirupsen/logrus" ) +// ProxyWriter calls recorder, used to debug logs type ProxyWriter struct { - W http.ResponseWriter + w http.ResponseWriter code int length int64 + + log *log.Logger } -func NewProxyWriter(writer http.ResponseWriter) *ProxyWriter { +// NewProxyWriter creates a new ProxyWriter +func NewProxyWriter(w http.ResponseWriter) *ProxyWriter { + return NewProxyWriterWithLogger(w, log.StandardLogger()) +} + +// NewProxyWriterWithLogger creates a new ProxyWriter +func NewProxyWriterWithLogger(w http.ResponseWriter, l *log.Logger) *ProxyWriter { return &ProxyWriter{ - W: writer, + w: w, + log: l, } } +// StatusCode gets status code func (p *ProxyWriter) StatusCode() int { if p.code == 0 { // per contract standard lib will set this to http.StatusOK if not set @@ -33,46 +44,54 @@ func (p *ProxyWriter) StatusCode() int { return p.code } +// GetLength gets content length func (p *ProxyWriter) GetLength() int64 { return p.length } +// Header gets response header func (p *ProxyWriter) Header() http.Header { - return p.W.Header() + return p.w.Header() } func (p *ProxyWriter) Write(buf []byte) (int, error) { p.length = p.length + int64(len(buf)) - return p.W.Write(buf) + return p.w.Write(buf) } +// WriteHeader writes status code func (p *ProxyWriter) WriteHeader(code int) { p.code = code - p.W.WriteHeader(code) + p.w.WriteHeader(code) } +// Flush flush the writer func (p *ProxyWriter) Flush() { - if f, ok := p.W.(http.Flusher); ok { + if f, ok := p.w.(http.Flusher); ok { f.Flush() } } +// CloseNotify returns a channel that receives at most a single value (true) +// when the client connection has gone away. func (p *ProxyWriter) CloseNotify() <-chan bool { - if cn, ok := p.W.(http.CloseNotifier); ok { + if cn, ok := p.w.(http.CloseNotifier); ok { return cn.CloseNotify() } - log.Debugf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.W)) + p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.w)) return make(<-chan bool) } +// Hijack lets the caller take over the connection. func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if hi, ok := p.W.(http.Hijacker); ok { + if hi, ok := p.w.(http.Hijacker); ok { return hi.Hijack() } - log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.W)) - return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.W)) + p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.w)) + return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.w)) } +// NewBufferWriter creates a new BufferWriter func NewBufferWriter(w io.WriteCloser) *BufferWriter { return &BufferWriter{ W: w, @@ -80,16 +99,19 @@ func NewBufferWriter(w io.WriteCloser) *BufferWriter { } } +// BufferWriter buffer writer type BufferWriter struct { H http.Header Code int W io.WriteCloser } +// Close close the writer func (b *BufferWriter) Close() error { return b.W.Close() } +// Header gets response header func (b *BufferWriter) Header() http.Header { return b.H } @@ -98,11 +120,13 @@ func (b *BufferWriter) Write(buf []byte) (int, error) { return b.W.Write(buf) } -// WriteHeader sets rw.Code. +// WriteHeader writes status code func (b *BufferWriter) WriteHeader(code int) { b.Code = code } +// CloseNotify returns a channel that receives at most a single value (true) +// when the client connection has gone away. func (b *BufferWriter) CloseNotify() <-chan bool { if cn, ok := b.W.(http.CloseNotifier); ok { return cn.CloseNotify() @@ -111,6 +135,7 @@ func (b *BufferWriter) CloseNotify() <-chan bool { return make(<-chan bool) } +// Hijack lets the caller take over the connection. func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hi, ok := b.W.(http.Hijacker); ok { return hi.Hijack() @@ -125,10 +150,10 @@ type nopWriteCloser struct { func (*nopWriteCloser) Close() error { return nil } -// NopCloser returns a WriteCloser with a no-op Close method wrapping +// NopWriteCloser returns a WriteCloser with a no-op Close method wrapping // the provided Writer w. func NopWriteCloser(w io.Writer) io.WriteCloser { - return &nopWriteCloser{w} + return &nopWriteCloser{Writer: w} } // CopyURL provides update safe copy by avoiding shallow copying User field diff --git a/vendor/github.com/vulcand/oxy/utils/source.go b/vendor/github.com/vulcand/oxy/utils/source.go index 4ed89e13b..5306b597c 100644 --- a/vendor/github.com/vulcand/oxy/utils/source.go +++ b/vendor/github.com/vulcand/oxy/utils/source.go @@ -6,21 +6,25 @@ import ( "strings" ) -// ExtractSource extracts the source from the request, e.g. that may be client ip, or particular header that +// SourceExtractor extracts the source from the request, e.g. that may be client ip, or particular header that // identifies the source. amount stands for amount of connections the source consumes, usually 1 for connection limiters // error should be returned when source can not be identified type SourceExtractor interface { Extract(req *http.Request) (token string, amount int64, err error) } +// ExtractorFunc extractor function type type ExtractorFunc func(req *http.Request) (token string, amount int64, err error) +// Extract extract from request func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) { return f(req) } +// ExtractSource extract source function type type ExtractSource func(req *http.Request) +// NewExtractor creates a new SourceExtractor func NewExtractor(variable string) (SourceExtractor, error) { if variable == "client.ip" { return ExtractorFunc(extractClientIP), nil @@ -31,17 +35,17 @@ func NewExtractor(variable string) (SourceExtractor, error) { if strings.HasPrefix(variable, "request.header.") { header := strings.TrimPrefix(variable, "request.header.") if len(header) == 0 { - return nil, fmt.Errorf("Wrong header: %s", header) + return nil, fmt.Errorf("wrong header: %s", header) } return makeHeaderExtractor(header), nil } - return nil, fmt.Errorf("Unsupported limiting variable: '%s'", variable) + return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable) } func extractClientIP(req *http.Request) (string, int64, error) { vals := strings.SplitN(req.RemoteAddr, ":", 2) if len(vals[0]) == 0 { - return "", 0, fmt.Errorf("Failed to parse client IP: %v", req.RemoteAddr) + return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr) } return vals[0], 1, nil } From a7bb768e98064610baaf37be95b9c873386752e1 Mon Sep 17 00:00:00 2001 From: SALLEYRON Julien Date: Mon, 20 Aug 2018 11:16:02 +0200 Subject: [PATCH 3/9] Remove TLS in API --- integration/consul_test.go | 21 +++-------- integration/etcd3_test.go | 42 +++++----------------- integration/etcd_test.go | 21 +++-------- integration/fixtures/file/dir/simple2.toml | 2 +- integration/try/condition.go | 25 +++++++++++++ integration/try/try.go | 35 +++++++++++++----- types/types.go | 2 +- 7 files changed, 69 insertions(+), 79 deletions(-) diff --git a/integration/consul_test.go b/integration/consul_test.go index 7014e5149..77860fc5a 100644 --- a/integration/consul_test.go +++ b/integration/consul_test.go @@ -585,21 +585,14 @@ func (s *ConsulSuite) TestSNIDynamicTlsConfig(c *check.C) { }) c.Assert(err, checker.IsNil) - // wait for traefik - err = try.GetRequest("http://127.0.0.1:8081/api/providers", 60*time.Second, try.BodyContains("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7hG")) - c.Assert(err, checker.IsNil) - req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client := &http.Client{Transport: tr1} req.Host = tr1.TLSClientConfig.ServerName req.Header.Set("Host", tr1.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - var resp *http.Response - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com")) c.Assert(err, checker.IsNil) - cn := resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "snitest.com") // now we configure the second keypair in consul and the request for host "snitest.org" will use the second keypair for key, value := range tlsconfigure2 { @@ -614,18 +607,12 @@ func (s *ConsulSuite) TestSNIDynamicTlsConfig(c *check.C) { }) c.Assert(err, checker.IsNil) - // waiting for traefik to pull configuration - err = try.GetRequest("http://127.0.0.1:8081/api/providers", 30*time.Second, try.BodyContains("MIIEogIBAAKCAQEAvG9kL+vF57+MICehzbqcQAUlAOSl5r")) - c.Assert(err, checker.IsNil) - req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client = &http.Client{Transport: tr2} req.Host = tr2.TLSClientConfig.ServerName req.Header.Set("Host", tr2.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr2, try.HasCn("snitest.org")) c.Assert(err, checker.IsNil) - cn = resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "snitest.org") } diff --git a/integration/etcd3_test.go b/integration/etcd3_test.go index 0d01d861e..3471dd9dd 100644 --- a/integration/etcd3_test.go +++ b/integration/etcd3_test.go @@ -532,21 +532,14 @@ func (s *Etcd3Suite) TestSNIDynamicTlsConfig(c *check.C) { c.Assert(err, checker.IsNil) defer cmd.Process.Kill() - // wait for Træfik - err = try.GetRequest("http://127.0.0.1:8081/api/providers", 60*time.Second, try.BodyContains(string("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h"))) - c.Assert(err, checker.IsNil) - req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client := &http.Client{Transport: tr1} req.Host = tr1.TLSClientConfig.ServerName req.Header.Set("Host", tr1.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - var resp *http.Response - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com")) c.Assert(err, checker.IsNil) - cn := resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "snitest.com") // now we configure the second keypair in etcd and the request for host "snitest.org" will use the second keypair @@ -562,20 +555,14 @@ func (s *Etcd3Suite) TestSNIDynamicTlsConfig(c *check.C) { }) c.Assert(err, checker.IsNil) - // waiting for Træfik to pull configuration - err = try.GetRequest("http://127.0.0.1:8081/api/providers", 30*time.Second, try.BodyContains("MIIEogIBAAKCAQEAvG9kL+vF57+MICehzbqcQAUlAOSl5r")) - c.Assert(err, checker.IsNil) - req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client = &http.Client{Transport: tr2} req.Host = tr2.TLSClientConfig.ServerName req.Header.Set("Host", tr2.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr2, try.HasCn("snitest.org")) c.Assert(err, checker.IsNil) - cn = resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "snitest.org") } func (s *Etcd3Suite) TestDeleteSNIDynamicTlsConfig(c *check.C) { @@ -646,21 +633,14 @@ func (s *Etcd3Suite) TestDeleteSNIDynamicTlsConfig(c *check.C) { c.Assert(err, checker.IsNil) defer cmd.Process.Kill() - // wait for Træfik - err = try.GetRequest(traefikWebEtcdURL+"api/providers", 60*time.Second, try.BodyContains(string("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h"))) - c.Assert(err, checker.IsNil) - req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client := &http.Client{Transport: tr1} req.Host = tr1.TLSClientConfig.ServerName req.Header.Set("Host", tr1.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - var resp *http.Response - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com")) c.Assert(err, checker.IsNil) - cn := resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "snitest.com") // now we delete the tls cert/key pairs,so the endpoint show use default cert/key pair for key := range tlsconfigure1 { @@ -668,18 +648,12 @@ func (s *Etcd3Suite) TestDeleteSNIDynamicTlsConfig(c *check.C) { c.Assert(err, checker.IsNil) } - // waiting for Træfik to pull configuration - err = try.GetRequest(traefikWebEtcdURL+"api/providers", 30*time.Second, try.BodyNotContains("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h")) - c.Assert(err, checker.IsNil) - req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client = &http.Client{Transport: tr1} req.Host = tr1.TLSClientConfig.ServerName req.Header.Set("Host", tr1.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("TRAEFIK DEFAULT CERT")) c.Assert(err, checker.IsNil) - cn = resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "TRAEFIK DEFAULT CERT") } diff --git a/integration/etcd_test.go b/integration/etcd_test.go index 6e6133f86..9ae0e303f 100644 --- a/integration/etcd_test.go +++ b/integration/etcd_test.go @@ -548,21 +548,14 @@ func (s *EtcdSuite) TestSNIDynamicTlsConfig(c *check.C) { c.Assert(err, checker.IsNil) defer cmd.Process.Kill() - // wait for Træfik - err = try.GetRequest("http://127.0.0.1:8081/api/providers", 60*time.Second, try.BodyContains(string("MIIEpQIBAAKCAQEA1RducBK6EiFDv3TYB8ZcrfKWRVaSfHzWicO3J5WdST9oS7h"))) - c.Assert(err, checker.IsNil) - req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client := &http.Client{Transport: tr1} req.Host = tr1.TLSClientConfig.ServerName req.Header.Set("Host", tr1.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - var resp *http.Response - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr1, try.HasCn("snitest.com")) c.Assert(err, checker.IsNil) - cn := resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "snitest.com") // now we configure the second keypair in etcd and the request for host "snitest.org" will use the second keypair @@ -578,18 +571,12 @@ func (s *EtcdSuite) TestSNIDynamicTlsConfig(c *check.C) { }) c.Assert(err, checker.IsNil) - // waiting for Træfik to pull configuration - err = try.GetRequest("http://127.0.0.1:8081/api/providers", 30*time.Second, try.BodyContains("MIIEogIBAAKCAQEAvG9kL+vF57+MICehzbqcQAUlAOSl5r")) - c.Assert(err, checker.IsNil) - req, err = http.NewRequest(http.MethodGet, "https://127.0.0.1:4443/", nil) c.Assert(err, checker.IsNil) - client = &http.Client{Transport: tr2} req.Host = tr2.TLSClientConfig.ServerName req.Header.Set("Host", tr2.TLSClientConfig.ServerName) req.Header.Set("Accept", "*/*") - resp, err = client.Do(req) + + err = try.RequestWithTransport(req, 30*time.Second, tr2, try.HasCn("snitest.org")) c.Assert(err, checker.IsNil) - cn = resp.TLS.PeerCertificates[0].Subject.CommonName - c.Assert(cn, checker.Equals, "snitest.org") } diff --git a/integration/fixtures/file/dir/simple2.toml b/integration/fixtures/file/dir/simple2.toml index e02f63550..dcbcffc57 100644 --- a/integration/fixtures/file/dir/simple2.toml +++ b/integration/fixtures/file/dir/simple2.toml @@ -2,7 +2,7 @@ [backends] [backends.backend2] [backends.backend2.servers.server1] - url = "http://172.17.0.2:80" + url = "http://172.17.0.123:80" weight = 1 [frontends] diff --git a/integration/try/condition.go b/integration/try/condition.go index e7ee9c9d2..a3d5d7656 100644 --- a/integration/try/condition.go +++ b/integration/try/condition.go @@ -88,6 +88,31 @@ func HasBody() ResponseCondition { } } +// HasCn returns a retry condition function. +// The condition returns an error if the cn is not correct. +func HasCn(cn string) ResponseCondition { + return func(res *http.Response) error { + if res.TLS == nil { + return errors.New("response doesn't have TLS") + } + + if len(res.TLS.PeerCertificates) == 0 { + return errors.New("response TLS doesn't have peer certificates") + } + + if res.TLS.PeerCertificates[0] == nil { + return errors.New("first peer certificate is nil") + } + + commonName := res.TLS.PeerCertificates[0].Subject.CommonName + if cn != commonName { + return fmt.Errorf("common name don't match: %s != %s", cn, commonName) + } + + return nil + } +} + // StatusCodeIs returns a retry condition function. // The condition returns an error if the given response's status code is not the // given HTTP status code. diff --git a/integration/try/try.go b/integration/try/try.go index f201cd0d8..0f75cbdaf 100644 --- a/integration/try/try.go +++ b/integration/try/try.go @@ -31,7 +31,7 @@ func Sleep(d time.Duration) { // response body needs to be closed or not. Callers are expected to close on // their own if the function returns a nil error. func Response(req *http.Request, timeout time.Duration) (*http.Response, error) { - return doTryRequest(req, timeout) + return doTryRequest(req, timeout, nil) } // ResponseUntilStatusCode is like Request, but returns the response for further @@ -40,7 +40,7 @@ func Response(req *http.Request, timeout time.Duration) (*http.Response, error) // response body needs to be closed or not. Callers are expected to close on // their own if the function returns a nil error. func ResponseUntilStatusCode(req *http.Request, timeout time.Duration, statusCode int) (*http.Response, error) { - return doTryRequest(req, timeout, StatusCodeIs(statusCode)) + return doTryRequest(req, timeout, nil, StatusCodeIs(statusCode)) } // GetRequest is like Do, but runs a request against the given URL and applies @@ -48,7 +48,7 @@ func ResponseUntilStatusCode(req *http.Request, timeout time.Duration, statusCod // ResponseCondition may be nil, in which case only the request against the URL must // succeed. func GetRequest(url string, timeout time.Duration, conditions ...ResponseCondition) error { - resp, err := doTryGet(url, timeout, conditions...) + resp, err := doTryGet(url, timeout, nil, conditions...) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -62,7 +62,21 @@ func GetRequest(url string, timeout time.Duration, conditions ...ResponseConditi // ResponseCondition may be nil, in which case only the request against the URL must // succeed. func Request(req *http.Request, timeout time.Duration, conditions ...ResponseCondition) error { - resp, err := doTryRequest(req, timeout, conditions...) + resp, err := doTryRequest(req, timeout, nil, conditions...) + + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + return err +} + +// RequestWithTransport is like Do, but runs a request against the given URL and applies +// the condition on the response. +// ResponseCondition may be nil, in which case only the request against the URL must +// succeed. +func RequestWithTransport(req *http.Request, timeout time.Duration, transport *http.Transport, conditions ...ResponseCondition) error { + resp, err := doTryRequest(req, timeout, transport, conditions...) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -112,24 +126,27 @@ func Do(timeout time.Duration, operation DoCondition) error { } } -func doTryGet(url string, timeout time.Duration, conditions ...ResponseCondition) (*http.Response, error) { +func doTryGet(url string, timeout time.Duration, transport *http.Transport, conditions ...ResponseCondition) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err } - return doTryRequest(req, timeout, conditions...) + return doTryRequest(req, timeout, transport, conditions...) } -func doTryRequest(request *http.Request, timeout time.Duration, conditions ...ResponseCondition) (*http.Response, error) { - return doRequest(Do, timeout, request, conditions...) +func doTryRequest(request *http.Request, timeout time.Duration, transport *http.Transport, conditions ...ResponseCondition) (*http.Response, error) { + return doRequest(Do, timeout, request, transport, conditions...) } -func doRequest(action timedAction, timeout time.Duration, request *http.Request, conditions ...ResponseCondition) (*http.Response, error) { +func doRequest(action timedAction, timeout time.Duration, request *http.Request, transport *http.Transport, conditions ...ResponseCondition) (*http.Response, error) { var resp *http.Response return resp, action(timeout, func() error { var err error client := http.DefaultClient + if transport != nil { + client.Transport = transport + } resp, err = client.Do(request) if err != nil { diff --git a/types/types.go b/types/types.go index eed737bee..ca893b7fc 100644 --- a/types/types.go +++ b/types/types.go @@ -235,7 +235,7 @@ type Configurations map[string]*Configuration type Configuration struct { Backends map[string]*Backend `json:"backends,omitempty"` Frontends map[string]*Frontend `json:"frontends,omitempty"` - TLS []*traefiktls.Configuration `json:"tls,omitempty"` + TLS []*traefiktls.Configuration `json:"-"` } // ConfigMessage hold configuration information exchanged between parts of traefik. From f062ee80c8daa0ede62f094e9a72a1f5d38dd0d0 Mon Sep 17 00:00:00 2001 From: Damien Duportal Date: Mon, 20 Aug 2018 12:02:03 +0200 Subject: [PATCH 4/9] Docs: Adding warnings and solution about the configuration exposure --- docs/configuration/api.md | 21 ++++++++++++++++++++- docs/index.md | 4 ++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/docs/configuration/api.md b/docs/configuration/api.md index 9da05c7e8..05aa23277 100644 --- a/docs/configuration/api.md +++ b/docs/configuration/api.md @@ -4,6 +4,9 @@ ```toml # API definition +# Warning: Enabling API will expose Træfik's configuration and secret. +# It is not recommended in production, +# unless secured by authentication and authorizations [api] # Name of the related entry point # @@ -12,7 +15,7 @@ # entryPoint = "traefik" - # Enabled Dashboard + # Enable Dashboard # # Optional # Default: true @@ -38,6 +41,22 @@ For more customization, see [entry points](/configuration/entrypoints/) document ![Web UI Health](/img/traefik-health.png) +## Security + +Enabling the API will expose all configuration elements, +including secret. + +It is not recommended in production, +unless secured by authentication and authorizations. + +A good sane default (but not exhaustive) set of recommendations +would be to apply the following protection mechanism: + +* _At application level:_ enabling HTTP [Basic Authentication](#authentication) +* _At transport level:_ NOT exposing publicly the API's port, +keeping it restricted over internal networks +(restricted networks as in https://en.wikipedia.org/wiki/Principle_of_least_privilege). + ## API | Path | Method | Description | diff --git a/docs/index.md b/docs/index.md index a0fb81280..a7b6cc339 100644 --- a/docs/index.md +++ b/docs/index.md @@ -86,6 +86,10 @@ services: - /var/run/docker.sock:/var/run/docker.sock # So that Traefik can listen to the Docker events ``` +!!! warning + Enabling the Web UI with the `--api` flag might exposes configuration elements. You can read more about this on the [API/Dashboard's Security section](/configuration/api#security). + + **That's it. Now you can launch Træfik!** Start your `reverse-proxy` with the following command: From 2beb5236d0c9fe0e8f964d69688b1a05d510db30 Mon Sep 17 00:00:00 2001 From: Damien Duportal Date: Mon, 20 Aug 2018 13:34:03 +0200 Subject: [PATCH 5/9] A tiny rewording on the documentation API's page --- docs/configuration/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration/api.md b/docs/configuration/api.md index 05aa23277..e6e8aa932 100644 --- a/docs/configuration/api.md +++ b/docs/configuration/api.md @@ -4,7 +4,7 @@ ```toml # API definition -# Warning: Enabling API will expose Træfik's configuration and secret. +# Warning: Enabling API will expose Træfik's configuration. # It is not recommended in production, # unless secured by authentication and authorizations [api] @@ -44,7 +44,7 @@ For more customization, see [entry points](/configuration/entrypoints/) document ## Security Enabling the API will expose all configuration elements, -including secret. +including sensitive data. It is not recommended in production, unless secured by authentication and authorizations. From feeb7f81a611eb58685ffd28478add6d179b091f Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Mon, 20 Aug 2018 14:46:02 +0200 Subject: [PATCH 6/9] Prepare Release v1.6.6 --- CHANGELOG.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a194c611f..d763ef91c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Change Log +## [v1.6.6](https://github.com/containous/traefik/tree/v1.6.6) (2018-08-20) +[All Commits](https://github.com/containous/traefik/compare/v1.6.5...v1.6.6) + +**Bug fixes:** +- **[acme]** Avoid duplicated ACME resolution ([#3751](https://github.com/containous/traefik/pull/3751) by [nmengin](https://github.com/nmengin)) +- **[api]** Remove TLS in API ([#3788](https://github.com/containous/traefik/pull/3788) by [Juliens](https://github.com/Juliens)) +- **[cluster]** Remove unusable `--cluster` flag ([#3616](https://github.com/containous/traefik/pull/3616) by [dtomcej](https://github.com/dtomcej)) +- **[ecs]** Fix bad condition in ECS provider ([#3609](https://github.com/containous/traefik/pull/3609) by [mmatur](https://github.com/mmatur)) +- Set keepalive on TCP socket so idleTimeout works ([#3740](https://github.com/containous/traefik/pull/3740) by [ajardan](https://github.com/ajardan)) + +**Documentation:** +- A tiny rewording on the documentation API's page ([#3794](https://github.com/containous/traefik/pull/3794) by [dduportal](https://github.com/dduportal)) +- Adding warnings and solution about the configuration exposure ([#3790](https://github.com/containous/traefik/pull/3790) by [dduportal](https://github.com/dduportal)) +- Fix path to the debug pprof API ([#3608](https://github.com/containous/traefik/pull/3608) by [multani](https://github.com/multani)) + +**Misc:** +- **[oxy,websocket]** Update oxy dependency ([#3777](https://github.com/containous/traefik/pull/3777) by [Juliens](https://github.com/Juliens)) + ## [v1.6.5](https://github.com/containous/traefik/tree/v1.6.5) (2018-07-09) [All Commits](https://github.com/containous/traefik/compare/v1.6.4...v1.6.5) From df41cd925e7764883c119346c95a26bfcc4e134f Mon Sep 17 00:00:00 2001 From: Emile Vauge Date: Mon, 20 Aug 2018 17:08:03 +0200 Subject: [PATCH 7/9] Add vulnerability form --- docs/index.md | 5 +++++ mkdocs.yml | 3 --- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/index.md b/docs/index.md index a7b6cc339..e5c038b59 100644 --- a/docs/index.md +++ b/docs/index.md @@ -203,3 +203,8 @@ Using the tiny Docker image: ```shell docker run -d -p 8080:8080 -p 80:80 -v $PWD/traefik.toml:/etc/traefik/traefik.toml traefik ``` + +## Security + +We want to keep Træfik safe for everyone. +If you've discovered a security vulnerability in Træfik, we appreciate your help in disclosing it to us in a responsible manner, using [this form](https://security.traefik.io). \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index cf530c9e8..673698651 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,9 +21,6 @@ theme: accent: 'light blue' feature: tabs: false - palette: - primary: 'cyan' - accent: 'cyan' i18n: prev: 'Previous' next: 'Next' From cf2d7497e4d199c99dc2fb0c7af87e1016ddc5be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Rodr=C3=ADguez?= Date: Mon, 20 Aug 2018 12:34:05 -0300 Subject: [PATCH 8/9] Mention docker-compose as a requirement in the let's encrypt guide --- docs/user-guide/docker-and-lets-encrypt.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/docker-and-lets-encrypt.md b/docs/user-guide/docker-and-lets-encrypt.md index 8fd443037..22eef0a10 100644 --- a/docs/user-guide/docker-and-lets-encrypt.md +++ b/docs/user-guide/docker-and-lets-encrypt.md @@ -8,7 +8,7 @@ In addition, we want to use Let's Encrypt to automatically generate and renew SS ## Setting Up -In order for this to work, you'll need a server with a public IP address, with Docker installed on it. +In order for this to work, you'll need a server with a public IP address, with Docker and docker-compose installed on it. In this example, we're using the fictitious domain _my-awesome-app.org_. From 27e4a8a227d32343fb40850d36c326a25c10a99c Mon Sep 17 00:00:00 2001 From: Emile Vauge Date: Mon, 20 Aug 2018 17:50:04 +0200 Subject: [PATCH 9/9] Fixes bad palette in doc --- mkdocs.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index 673698651..f85f2c2fa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -17,8 +17,8 @@ theme: favicon: img/traefik.icon.png logo: img/traefik.logo.png palette: - primary: 'blue' - accent: 'light blue' + primary: 'cyan' + accent: 'cyan' feature: tabs: false i18n: