diff --git a/middlewares/metrics.go b/middlewares/metrics.go index 08c4a6350..f09f74512 100644 --- a/middlewares/metrics.go +++ b/middlewares/metrics.go @@ -1,22 +1,29 @@ package middlewares import ( - "github.com/go-kit/kit/metrics" "net/http" "strconv" "time" + + "github.com/go-kit/kit/metrics" ) // Metrics is an Interface that must be satisfied by any system that -// wants to expose and monitor metrics +// wants to expose and monitor Metrics. type Metrics interface { getReqsCounter() metrics.Counter - getLatencyHistogram() metrics.Histogram - handler() http.Handler + getReqDurationHistogram() metrics.Histogram + RetryMetrics +} + +// RetryMetrics must be satisfied by any system that wants to collect and +// expose retry specific Metrics. +type RetryMetrics interface { + getRetryCounter() metrics.Counter } // MetricsWrapper is a Negroni compatible Handler which relies on a -// given Metrics implementation to expose and monitor Traefik metrics +// given Metrics implementation to expose and monitor Traefik Metrics. type MetricsWrapper struct { Impl Metrics } @@ -35,17 +42,25 @@ func (m *MetricsWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request, next start := time.Now() prw := &responseRecorder{rw, http.StatusOK} next(prw, r) - labels := []string{"code", strconv.Itoa(prw.StatusCode()), "method", r.Method} + labels := []string{"code", strconv.Itoa(prw.statusCode), "method", r.Method} m.Impl.getReqsCounter().With(labels...).Add(1) - m.Impl.getLatencyHistogram().Observe(float64(time.Since(start).Seconds())) + m.Impl.getReqDurationHistogram().Observe(float64(time.Since(start).Seconds())) } -func (rw *responseRecorder) StatusCode() int { - return rw.statusCode +// MetricsRetryListener is an implementation of the RetryListener interface to +// record Metrics about retry attempts. +type MetricsRetryListener struct { + retryMetrics RetryMetrics } -// Handler is the chance for the Metrics implementation -// to expose its metrics on a server endpoint -func (m *MetricsWrapper) Handler() http.Handler { - return m.Impl.handler() +// Retried tracks the retry in the Metrics implementation. +func (m *MetricsRetryListener) Retried(attempt int) { + if m.retryMetrics != nil { + m.retryMetrics.getRetryCounter().Add(1) + } +} + +// NewMetricsRetryListener instantiates a MetricsRetryListener with the given RetryMetrics. +func NewMetricsRetryListener(retryMetrics RetryMetrics) RetryListener { + return &MetricsRetryListener{retryMetrics: retryMetrics} } diff --git a/middlewares/metrics_test.go b/middlewares/metrics_test.go new file mode 100644 index 000000000..5e77efaba --- /dev/null +++ b/middlewares/metrics_test.go @@ -0,0 +1,48 @@ +package middlewares + +import ( + "testing" + + "github.com/go-kit/kit/metrics" +) + +func TestMetricsRetryListener(t *testing.T) { + // nil implementation, nothing should fail + retryListener := NewMetricsRetryListener(nil) + retryListener.Retried(1) + + retryMetrics := newCollectingMetrics() + retryListener = NewMetricsRetryListener(retryMetrics) + retryListener.Retried(1) + retryListener.Retried(2) + + wantCounterValue := float64(2) + if retryMetrics.retryCounter.counterValue != wantCounterValue { + t.Errorf("got counter value of %d, want %d", retryMetrics.retryCounter.counterValue, wantCounterValue) + } +} + +// collectingRetryMetrics is an implementation of the RetryMetrics interface that can be used inside tests to collect the times Add() was called. +type collectingRetryMetrics struct { + retryCounter *collectingCounter +} + +func newCollectingMetrics() collectingRetryMetrics { + return collectingRetryMetrics{retryCounter: &collectingCounter{}} +} + +func (metrics collectingRetryMetrics) getRetryCounter() metrics.Counter { + return metrics.retryCounter +} + +type collectingCounter struct { + counterValue float64 +} + +func (c *collectingCounter) With(labelValues ...string) metrics.Counter { + panic("collectingCounter.With not implemented!") +} + +func (c *collectingCounter) Add(delta float64) { + c.counterValue += delta +} diff --git a/middlewares/prometheus.go b/middlewares/prometheus.go index f41a51ffb..f2192fbf8 100644 --- a/middlewares/prometheus.go +++ b/middlewares/prometheus.go @@ -1,57 +1,64 @@ package middlewares import ( + "fmt" + "github.com/containous/traefik/types" "github.com/go-kit/kit/metrics" "github.com/go-kit/kit/metrics/prometheus" stdprometheus "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "net/http" ) const ( - reqsName = "traefik_requests_total" - latencyName = "traefik_request_duration_seconds" + reqsTotalName = "traefik_requests_total" + reqDurationName = "traefik_request_duration_seconds" + retriesTotalName = "traefik_backend_retries_total" ) -// Prometheus is an Implementation for Metrics that exposes prometheus metrics for the latency -// and the number of requests partitioned by status code and method. +// Prometheus is an Implementation for Metrics that exposes the following Prometheus metrics: +// - number of requests partitioned by status code and method +// - request durations +// - amount of retries happened type Prometheus struct { - reqsCounter metrics.Counter - latencyHistogram metrics.Histogram + reqsCounter metrics.Counter + reqDurationHistogram metrics.Histogram + retryCounter metrics.Counter } func (p *Prometheus) getReqsCounter() metrics.Counter { return p.reqsCounter } -func (p *Prometheus) getLatencyHistogram() metrics.Histogram { - return p.latencyHistogram +func (p *Prometheus) getReqDurationHistogram() metrics.Histogram { + return p.reqDurationHistogram } -// NewPrometheus returns a new prometheus Metrics implementation. -func NewPrometheus(name string, config *types.Prometheus) *Prometheus { - var m Prometheus +func (p *Prometheus) getRetryCounter() metrics.Counter { + return p.retryCounter +} + +// NewPrometheus returns a new Prometheus Metrics implementation. +// With the returned collectors you have the possibility to clean up the internal Prometheus state by unsubscribing the collectors. +// This is for example useful while testing the Prometheus implementation. +// If any of the Prometheus Metrics can not be registered an error will be returned and the returned Metrics implementation will be nil. +func NewPrometheus(name string, config *types.Prometheus) (*Prometheus, []stdprometheus.Collector, error) { + var prom Prometheus + var collectors []stdprometheus.Collector cv := stdprometheus.NewCounterVec( stdprometheus.CounterOpts{ - Name: reqsName, + Name: reqsTotalName, Help: "How many HTTP requests processed, partitioned by status code and method.", ConstLabels: stdprometheus.Labels{"service": name}, }, []string{"code", "method"}, ) - - err := stdprometheus.Register(cv) + cv, err := registerCounterVec(cv) if err != nil { - e, ok := err.(stdprometheus.AlreadyRegisteredError) - if !ok { - panic(err) - } - m.reqsCounter = prometheus.NewCounter(e.ExistingCollector.(*stdprometheus.CounterVec)) - } else { - m.reqsCounter = prometheus.NewCounter(cv) + return nil, collectors, err } + prom.reqsCounter = prometheus.NewCounter(cv) + collectors = append(collectors, cv) var buckets []float64 if config.Buckets != nil { @@ -59,31 +66,64 @@ func NewPrometheus(name string, config *types.Prometheus) *Prometheus { } else { buckets = []float64{0.1, 0.3, 1.2, 5} } - hv := stdprometheus.NewHistogramVec( stdprometheus.HistogramOpts{ - Name: latencyName, + Name: reqDurationName, Help: "How long it took to process the request.", ConstLabels: stdprometheus.Labels{"service": name}, Buckets: buckets, }, []string{}, ) + hv, err = registerHistogramVec(hv) + if err != nil { + return nil, collectors, err + } + prom.reqDurationHistogram = prometheus.NewHistogram(hv) + collectors = append(collectors, hv) + + cv = stdprometheus.NewCounterVec( + stdprometheus.CounterOpts{ + Name: retriesTotalName, + Help: "How many request retries happened in total.", + ConstLabels: stdprometheus.Labels{"service": name}, + }, + []string{}, + ) + cv, err = registerCounterVec(cv) + if err != nil { + return nil, collectors, err + } + prom.retryCounter = prometheus.NewCounter(cv) + collectors = append(collectors, cv) + + return &prom, collectors, nil +} + +func registerCounterVec(cv *stdprometheus.CounterVec) (*stdprometheus.CounterVec, error) { + err := stdprometheus.Register(cv) - err = stdprometheus.Register(hv) if err != nil { e, ok := err.(stdprometheus.AlreadyRegisteredError) if !ok { - panic(err) + return nil, fmt.Errorf("error registering CounterVec: %s", e) } - m.latencyHistogram = prometheus.NewHistogram(e.ExistingCollector.(*stdprometheus.HistogramVec)) - } else { - m.latencyHistogram = prometheus.NewHistogram(hv) + cv = e.ExistingCollector.(*stdprometheus.CounterVec) } - return &m + return cv, nil } -func (p *Prometheus) handler() http.Handler { - return promhttp.Handler() +func registerHistogramVec(hv *stdprometheus.HistogramVec) (*stdprometheus.HistogramVec, error) { + err := stdprometheus.Register(hv) + + if err != nil { + e, ok := err.(stdprometheus.AlreadyRegisteredError) + if !ok { + return nil, fmt.Errorf("error registering HistogramVec: %s", e) + } + hv = e.ExistingCollector.(*stdprometheus.HistogramVec) + } + + return hv, nil } diff --git a/middlewares/prometheus_test.go b/middlewares/prometheus_test.go index 08931ac8b..5b011c64b 100644 --- a/middlewares/prometheus_test.go +++ b/middlewares/prometheus_test.go @@ -2,6 +2,8 @@ package middlewares import ( "fmt" + "io" + "io/ioutil" "net/http" "net/http/httptest" "strings" @@ -16,55 +18,37 @@ import ( ) func TestPrometheus(t *testing.T) { - metricsFamily, err := prometheus.DefaultGatherer.Gather() + defer resetPrometheusValues() + + metricsFamilies, err := prometheus.DefaultGatherer.Gather() if err != nil { t.Fatalf("could not gather metrics family: %s", err) } - initialMetricsFamilyCount := len(metricsFamily) + initialMetricsFamilyCount := len(metricsFamilies) recorder := httptest.NewRecorder() - n := negroni.New() - metricsMiddlewareBackend := NewMetricsWrapper(NewPrometheus("test", &types.Prometheus{})) - n.Use(metricsMiddlewareBackend) - r := http.NewServeMux() - r.Handle("/metrics", promhttp.Handler()) - r.HandleFunc(`/ok`, func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - fmt.Fprintln(w, "ok") - }) - n.UseHandler(r) + req1 := mustNewRequest("GET", "http://localhost:3000/ok", ioutil.NopCloser(nil)) + req2 := mustNewRequest("GET", "http://localhost:3000/metrics", ioutil.NopCloser(nil)) - req1, err := http.NewRequest("GET", "http://localhost:3000/ok", nil) - if err != nil { - t.Error(err) - } - req2, err := http.NewRequest("GET", "http://localhost:3000/metrics", nil) - if err != nil { - t.Error(err) - } + httpHandler := setupTestHTTPHandler() + httpHandler.ServeHTTP(recorder, req1) + httpHandler.ServeHTTP(recorder, req2) - n.ServeHTTP(recorder, req1) - n.ServeHTTP(recorder, req2) body := recorder.Body.String() - if !strings.Contains(body, reqsName) { - t.Errorf("body does not contain request total entry '%s'", reqsName) + if !strings.Contains(body, reqsTotalName) { + t.Errorf("body does not contain request total entry '%s'", reqsTotalName) } - if !strings.Contains(body, latencyName) { - t.Errorf("body does not contain request duration entry '%s'", latencyName) + if !strings.Contains(body, reqDurationName) { + t.Errorf("body does not contain request duration entry '%s'", reqDurationName) + } + if !strings.Contains(body, retriesTotalName) { + t.Errorf("body does not contain total retries entry '%s'", retriesTotalName) } - // Register the same metrics again - metricsMiddlewareBackend = NewMetricsWrapper(NewPrometheus("test", &types.Prometheus{})) - n = negroni.New() - n.Use(metricsMiddlewareBackend) - n.UseHandler(r) - - n.ServeHTTP(recorder, req2) - - metricsFamily, err = prometheus.DefaultGatherer.Gather() + metricsFamilies, err = prometheus.DefaultGatherer.Gather() if err != nil { - t.Fatalf("could not gather metrics family: %s", err) + t.Fatalf("could not gather metrics families: %s", err) } tests := []struct { @@ -73,7 +57,7 @@ func TestPrometheus(t *testing.T) { assert func(*dto.MetricFamily) }{ { - name: reqsName, + name: reqsTotalName, labels: map[string]string{ "code": "200", "method": "GET", @@ -81,29 +65,44 @@ func TestPrometheus(t *testing.T) { }, assert: func(family *dto.MetricFamily) { cv := uint(family.Metric[0].Counter.GetValue()) - if cv != 3 { - t.Errorf("gathered metrics do not contain correct value for total requests, got %d", cv) + expectedCv := uint(2) + if cv != expectedCv { + t.Errorf("gathered metrics do not contain correct value for total requests, got %d expected %d", cv, expectedCv) } }, }, { - name: latencyName, + name: reqDurationName, labels: map[string]string{ "service": "test", }, assert: func(family *dto.MetricFamily) { sc := family.Metric[0].Histogram.GetSampleCount() - if sc != 3 { - t.Errorf("gathered metrics do not contain correct sample count for request duration, got %d", sc) + expectedSc := uint64(2) + if sc != expectedSc { + t.Errorf("gathered metrics do not contain correct sample count for request duration, got %d expected %d", sc, expectedSc) + } + }, + }, + { + name: retriesTotalName, + labels: map[string]string{ + "service": "test", + }, + assert: func(family *dto.MetricFamily) { + cv := uint(family.Metric[0].Counter.GetValue()) + expectedCv := uint(1) + if cv != expectedCv { + t.Errorf("gathered metrics do not contain correct value for total retries, got '%d' expected '%d'", cv, expectedCv) } }, }, } - assert.Equal(t, len(tests), len(metricsFamily)-initialMetricsFamilyCount, "gathered traefic metrics count does not match tests count") + assert.Equal(t, len(tests), len(metricsFamilies)-initialMetricsFamilyCount, "gathered traefic metrics count does not match tests count") for _, test := range tests { - family := findMetricFamily(test.name, metricsFamily) + family := findMetricFamily(test.name, metricsFamilies) if family == nil { t.Errorf("gathered metrics do not contain '%s'", test.name) continue @@ -120,6 +119,74 @@ func TestPrometheus(t *testing.T) { } } +func TestPrometheusRegisterMetricsMultipleTimes(t *testing.T) { + defer resetPrometheusValues() + + recorder := httptest.NewRecorder() + req1 := mustNewRequest("GET", "http://localhost:3000/ok", ioutil.NopCloser(nil)) + + httpHandler := setupTestHTTPHandler() + httpHandler.ServeHTTP(recorder, req1) + + httpHandler = setupTestHTTPHandler() + httpHandler.ServeHTTP(recorder, req1) + + metricsFamilies, err := prometheus.DefaultGatherer.Gather() + if err != nil { + t.Fatalf("could not gather metrics families: %s", err) + } + + reqsTotalFamily := findMetricFamily(reqsTotalName, metricsFamilies) + if reqsTotalFamily == nil { + t.Fatalf("gathered metrics do not contain '%s'", reqsTotalName) + } + + cv := uint(reqsTotalFamily.Metric[0].Counter.GetValue()) + expectedCv := uint(2) + if cv != expectedCv { + t.Errorf("wrong counter value when registering metrics multiple times, got '%d' expected '%d'", cv, expectedCv) + } +} + +func setupTestHTTPHandler() http.Handler { + serveMux := http.NewServeMux() + serveMux.Handle("/metrics", promhttp.Handler()) + serveMux.Handle("/ok", &networkFailingHTTPHandler{failAtCalls: []int{1}}) + + metrics, _ := newPrometheusMetrics() + + n := negroni.New() + n.Use(NewMetricsWrapper(metrics)) + n.UseHandler(NewRetry(2, serveMux, NewMetricsRetryListener(metrics))) + + return n +} + +// mustNewRequest is like http.NewRequest but panics if an error occurs. +func mustNewRequest(method, urlStr string, body io.Reader) *http.Request { + req, err := http.NewRequest(method, urlStr, body) + if err != nil { + panic(fmt.Sprintf("NewRequest(%s, %s, %+v): %s", method, urlStr, body, err)) + } + return req +} + +func resetPrometheusValues() { + _, collectors := newPrometheusMetrics() + + for _, collector := range collectors { + prometheus.Unregister(collector) + } +} + +func newPrometheusMetrics() (*Prometheus, []prometheus.Collector) { + prom, collectors, err := NewPrometheus("test", &types.Prometheus{}) + if err != nil { + panic(fmt.Sprintf("Error creating Prometheus Metrics: %s", err)) + } + return prom, collectors +} + func findMetricFamily(name string, families []*dto.MetricFamily) *dto.MetricFamily { for _, family := range families { if family.GetName() == name { diff --git a/middlewares/retry.go b/middlewares/retry.go index 031490045..41c8a225f 100644 --- a/middlewares/retry.go +++ b/middlewares/retry.go @@ -11,21 +11,24 @@ import ( "github.com/vulcand/oxy/utils" ) +// Compile time validation responseRecorder implements http interfaces correctly. var ( - _ Stateful = &ResponseRecorder{} + _ Stateful = &retryResponseRecorder{} ) // Retry is a middleware that retries requests type Retry struct { attempts int next http.Handler + listener RetryListener } // NewRetry returns a new Retry instance -func NewRetry(attempts int, next http.Handler) *Retry { +func NewRetry(attempts int, next http.Handler, listener RetryListener) *Retry { return &Retry{ attempts: attempts, next: next, + listener: listener, } } @@ -39,7 +42,7 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } attempts := 1 for { - recorder := NewRecorder() + recorder := newRetryResponseRecorder() recorder.responseWriter = rw retry.next.ServeHTTP(recorder, r) if !isNetworkError(recorder.Code) || attempts >= retry.attempts { @@ -50,6 +53,7 @@ func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } attempts++ log.Debugf("New attempt %d for request: %v", attempts, r.URL) + retry.listener.Retried(attempts) } } @@ -57,9 +61,16 @@ func isNetworkError(status int) bool { return status == http.StatusBadGateway || status == http.StatusGatewayTimeout } -// ResponseRecorder is an implementation of http.ResponseWriter that -// records its mutations for later inspection in tests. -type ResponseRecorder struct { +// RetryListener is used to inform about retry attempts. +type RetryListener interface { + // Retried will be called when a retry happens, with the request attempt passed to it. + // For the first retry this will be attempt 2. + Retried(attempt int) +} + +// retryResponseRecorder is an implementation of http.ResponseWriter that +// records its mutations for later inspection. +type retryResponseRecorder struct { Code int // the HTTP response code from WriteHeader HeaderMap http.Header // the HTTP response headers Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to @@ -68,9 +79,9 @@ type ResponseRecorder struct { err error } -// NewRecorder returns an initialized ResponseRecorder. -func NewRecorder() *ResponseRecorder { - return &ResponseRecorder{ +// newRetryResponseRecorder returns an initialized retryResponseRecorder. +func newRetryResponseRecorder() *retryResponseRecorder { + return &retryResponseRecorder{ HeaderMap: make(http.Header), Body: new(bytes.Buffer), Code: 200, @@ -78,7 +89,7 @@ func NewRecorder() *ResponseRecorder { } // Header returns the response headers. -func (rw *ResponseRecorder) Header() http.Header { +func (rw *retryResponseRecorder) Header() http.Header { m := rw.HeaderMap if m == nil { m = make(http.Header) @@ -88,7 +99,7 @@ func (rw *ResponseRecorder) Header() http.Header { } // Write always succeeds and writes to rw.Body, if not nil. -func (rw *ResponseRecorder) Write(buf []byte) (int, error) { +func (rw *retryResponseRecorder) Write(buf []byte) (int, error) { if rw.err != nil { return 0, rw.err } @@ -96,27 +107,27 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) { } // WriteHeader sets rw.Code. -func (rw *ResponseRecorder) WriteHeader(code int) { +func (rw *retryResponseRecorder) WriteHeader(code int) { rw.Code = code } // Hijack hijacks the connection -func (rw *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (rw *retryResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { return rw.responseWriter.(http.Hijacker).Hijack() } // CloseNotify returns a channel that receives at most a // single value (true) when the client connection has gone // away. -func (rw *ResponseRecorder) CloseNotify() <-chan bool { +func (rw *retryResponseRecorder) CloseNotify() <-chan bool { return rw.responseWriter.(http.CloseNotifier).CloseNotify() } // Flush sends any buffered data to the client. -func (rw *ResponseRecorder) Flush() { +func (rw *retryResponseRecorder) Flush() { _, err := rw.responseWriter.Write(rw.Body.Bytes()) if err != nil { - log.Errorf("Error writing response in ResponseRecorder: %s", err) + log.Errorf("Error writing response in retryResponseRecorder: %s", err) rw.err = err } rw.Body.Reset() diff --git a/middlewares/retry_test.go b/middlewares/retry_test.go new file mode 100644 index 000000000..cd161dc44 --- /dev/null +++ b/middlewares/retry_test.go @@ -0,0 +1,91 @@ +package middlewares + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRetry(t *testing.T) { + testCases := []struct { + failAtCalls []int + attempts int + responseStatus int + listener *countingRetryListener + retriedCount int + }{ + { + failAtCalls: []int{1, 2}, + attempts: 3, + responseStatus: http.StatusOK, + listener: &countingRetryListener{}, + retriedCount: 2, + }, + { + failAtCalls: []int{1, 2}, + attempts: 2, + responseStatus: http.StatusBadGateway, + listener: &countingRetryListener{}, + retriedCount: 1, + }, + } + + for _, tc := range testCases { + // bind tc locally + tc := tc + tcName := fmt.Sprintf("FailAtCalls(%v) RetryAttempts(%v)", tc.failAtCalls, tc.attempts) + + t.Run(tcName, func(t *testing.T) { + t.Parallel() + + var httpHandler http.Handler + httpHandler = &networkFailingHTTPHandler{failAtCalls: tc.failAtCalls} + httpHandler = NewRetry(tc.attempts, httpHandler, tc.listener) + + recorder := httptest.NewRecorder() + req, err := http.NewRequest("GET", "http://localhost:3000/ok", ioutil.NopCloser(nil)) + if err != nil { + t.Fatalf("could not create request: %+v", err) + } + + httpHandler.ServeHTTP(recorder, req) + + if tc.responseStatus != recorder.Code { + t.Errorf("wrong status code %d, want %d", recorder.Code, tc.responseStatus) + } + if tc.retriedCount != tc.listener.timesCalled { + t.Errorf("RetryListener called %d times, want %d times", tc.listener.timesCalled, tc.retriedCount) + } + }) + } +} + +// networkFailingHTTPHandler is an http.Handler implementation you can use to test retries. +type networkFailingHTTPHandler struct { + failAtCalls []int + callNumber int +} + +func (handler *networkFailingHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + handler.callNumber++ + + for _, failAtCall := range handler.failAtCalls { + if handler.callNumber == failAtCall { + w.WriteHeader(http.StatusBadGateway) + return + } + } + + w.WriteHeader(http.StatusOK) +} + +// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called. +type countingRetryListener struct { + timesCalled int +} + +func (l *countingRetryListener) Retried(attempt int) { + l.timesCalled++ +} diff --git a/server/server.go b/server/server.go index 1a95ebab7..982aaa313 100644 --- a/server/server.go +++ b/server/server.go @@ -192,11 +192,9 @@ func (server *Server) startHTTPServers() { if server.accessLoggerMiddleware != nil { serverMiddlewares = append(serverMiddlewares, server.accessLoggerMiddleware) } - if server.globalConfiguration.Web != nil && server.globalConfiguration.Web.Metrics != nil { - if server.globalConfiguration.Web.Metrics.Prometheus != nil { - metricsMiddleware := middlewares.NewMetricsWrapper(middlewares.NewPrometheus(newServerEntryPointName, server.globalConfiguration.Web.Metrics.Prometheus)) - serverMiddlewares = append(serverMiddlewares, metricsMiddleware) - } + metrics := newMetrics(server.globalConfiguration, newServerEntryPointName) + if metrics != nil { + serverMiddlewares = append(serverMiddlewares, middlewares.NewMetricsWrapper(metrics)) } if server.globalConfiguration.Web != nil && server.globalConfiguration.Web.Statistics != nil { statsRecorder = middlewares.NewStatsRecorder(server.globalConfiguration.Web.Statistics.RecentErrors) @@ -726,21 +724,15 @@ func (server *Server) loadConfig(configurations configs, globalConfiguration Glo continue frontend } } - // retry ? - if globalConfiguration.Retry != nil { - retries := len(configuration.Backends[frontend.Backend].Servers) - if globalConfiguration.Retry.Attempts > 0 { - retries = globalConfiguration.Retry.Attempts - } - lb = middlewares.NewRetry(retries, lb) - log.Debugf("Creating retries max attempts %d", retries) - } - if server.globalConfiguration.Web != nil && server.globalConfiguration.Web.Metrics != nil { - if server.globalConfiguration.Web.Metrics.Prometheus != nil { - metricsMiddlewareBackend := middlewares.NewMetricsWrapper(middlewares.NewPrometheus(frontend.Backend, server.globalConfiguration.Web.Metrics.Prometheus)) - negroni.Use(metricsMiddlewareBackend) - } + metrics := newMetrics(server.globalConfiguration, frontend.Backend) + + if globalConfiguration.Retry != nil { + retryListener := middlewares.NewMetricsRetryListener(metrics) + lb = registerRetryMiddleware(lb, globalConfiguration, configuration, frontend.Backend, retryListener) + } + if metrics != nil { + negroni.Use(middlewares.NewMetricsWrapper(metrics)) } ipWhitelistMiddleware, err := configureIPWhitelistMiddleware(frontend.WhitelistSourceRange) @@ -961,3 +953,37 @@ func (*Server) configureBackends(backends map[string]*types.Backend) { } } } + +// newMetrics instantiates the proper Metrics implementation, depending on the global configuration. +// Note that given there is no metrics instrumentation configured, it will return nil. +func newMetrics(globalConfig GlobalConfiguration, name string) middlewares.Metrics { + metricsEnabled := globalConfig.Web != nil && globalConfig.Web.Metrics != nil + if metricsEnabled && globalConfig.Web.Metrics.Prometheus != nil { + metrics, _, err := middlewares.NewPrometheus(name, globalConfig.Web.Metrics.Prometheus) + if err != nil { + log.Errorf("Error creating Prometheus Metrics implementation: %s", err) + return nil + } + return metrics + } + + return nil +} + +func registerRetryMiddleware( + httpHandler http.Handler, + globalConfig GlobalConfiguration, + config *types.Configuration, + backend string, + listener middlewares.RetryListener, +) http.Handler { + retries := len(config.Backends[backend].Servers) + if globalConfig.Retry.Attempts > 0 { + retries = globalConfig.Retry.Attempts + } + + httpHandler = middlewares.NewRetry(retries, httpHandler, listener) + log.Debugf("Creating retries max attempts %d", retries) + + return httpHandler +} diff --git a/server/server_test.go b/server/server_test.go index 63560514a..93c37d01b 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -11,6 +11,7 @@ import ( "github.com/containous/flaeg" "github.com/containous/mux" "github.com/containous/traefik/healthcheck" + "github.com/containous/traefik/middlewares" "github.com/containous/traefik/testhelpers" "github.com/containous/traefik/types" "github.com/davecgh/go-spew/spew" @@ -409,3 +410,109 @@ func TestConfigureBackends(t *testing.T) { }) } } + +func TestNewMetrics(t *testing.T) { + testCases := []struct { + desc string + globalConfig GlobalConfiguration + }{ + { + desc: "metrics disabled", + globalConfig: GlobalConfiguration{}, + }, + { + desc: "prometheus metrics", + globalConfig: GlobalConfiguration{ + Web: &WebProvider{ + Metrics: &types.Metrics{ + Prometheus: &types.Prometheus{ + Buckets: types.Buckets{0.1, 0.3, 1.2, 5.0}, + }, + }, + }, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + metricsImpl := newMetrics(tc.globalConfig, "test1") + if metricsImpl != nil { + if _, ok := metricsImpl.(*middlewares.Prometheus); !ok { + t.Errorf("invalid metricsImpl type, got %T want %T", metricsImpl, &middlewares.Prometheus{}) + } + } + }) + } +} + +func TestRegisterRetryMiddleware(t *testing.T) { + testCases := []struct { + name string + globalConfig GlobalConfiguration + countServers int + expectedRetries int + }{ + { + name: "configured retry attempts", + globalConfig: GlobalConfiguration{ + Retry: &Retry{ + Attempts: 3, + }, + }, + expectedRetries: 3, + }, + { + name: "retry attempts defaults to server amount", + globalConfig: GlobalConfiguration{ + Retry: &Retry{}, + }, + expectedRetries: 2, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var retryListener middlewares.RetryListener + httpHandler := okHTTPHandler{} + dynamicConfig := &types.Configuration{ + Backends: map[string]*types.Backend{ + "backend": { + Servers: map[string]types.Server{ + "server": { + URL: "http://localhost", + }, + "server2": { + URL: "http://localhost", + }, + }, + }, + }, + } + + httpHandlerWithRetry := registerRetryMiddleware(httpHandler, tc.globalConfig, dynamicConfig, "backend", retryListener) + + retry, ok := httpHandlerWithRetry.(*middlewares.Retry) + if !ok { + t.Fatalf("httpHandler was not decorated with retry httpHandler, got %#v", httpHandlerWithRetry) + } + + expectedRetry := middlewares.NewRetry(tc.expectedRetries, httpHandler, retryListener) + if !reflect.DeepEqual(retry, expectedRetry) { + t.Errorf("retry httpHandler was not instantiated correctly, got %#v want %#v", retry, expectedRetry) + } + }) + } +} + +type okHTTPHandler struct{} + +func (okHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +}