From bee8ebb00b4bdaea1795087d0dce33759c8453ec Mon Sep 17 00:00:00 2001 From: SALLEYRON Julien Date: Wed, 22 Nov 2017 18:20:03 +0100 Subject: [PATCH] Resync oxy with original repository --- glide.lock | 4 +- glide.yaml | 2 +- integration/grpc_test.go | 4 +- integration/websocket_test.go | 2 +- server/adapters.go | 21 - server/server.go | 17 +- .../vulcand/oxy/cbreaker/cbreaker.go | 27 +- .../github.com/vulcand/oxy/cbreaker/effect.go | 4 +- .../vulcand/oxy/cbreaker/fallback.go | 38 +- .../vulcand/oxy/cbreaker/predicates.go | 3 +- .../github.com/vulcand/oxy/cbreaker/ratio.go | 4 + .../vulcand/oxy/connlimit/connlimit.go | 27 +- vendor/github.com/vulcand/oxy/forward/fwd.go | 394 ++++++++++-------- .../github.com/vulcand/oxy/forward/headers.go | 1 - .../vulcand/oxy/forward/responseflusher.go | 53 --- .../github.com/vulcand/oxy/forward/rewrite.go | 35 +- .../vulcand/oxy/memmetrics/anomaly.go | 2 +- .../vulcand/oxy/memmetrics/counter.go | 4 +- .../vulcand/oxy/memmetrics/histogram.go | 33 +- .../vulcand/oxy/memmetrics/roundtrip.go | 61 ++- .../vulcand/oxy/ratelimit/tokenlimiter.go | 17 +- .../oxy/roundrobin/RequestRewriteListener.go | 5 + .../vulcand/oxy/roundrobin/rebalancer.go | 79 ++-- .../github.com/vulcand/oxy/roundrobin/rr.go | 64 ++- .../vulcand/oxy/roundrobin/stickysessions.go | 23 +- .../github.com/vulcand/oxy/stream/stream.go | 323 ++------------ .../vulcand/oxy/stream/threshold.go | 2 - vendor/github.com/vulcand/oxy/utils/auth.go | 12 +- .../github.com/vulcand/oxy/utils/dumpreq.go | 60 +++ .../github.com/vulcand/oxy/utils/logging.go | 86 ---- .../github.com/vulcand/oxy/utils/netutils.go | 51 ++- 31 files changed, 650 insertions(+), 808 deletions(-) delete mode 100644 vendor/github.com/vulcand/oxy/forward/responseflusher.go create mode 100644 vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go create mode 100644 vendor/github.com/vulcand/oxy/utils/dumpreq.go delete mode 100644 vendor/github.com/vulcand/oxy/utils/logging.go diff --git a/glide.lock b/glide.lock index b1c8c395c..72640efa1 100644 --- a/glide.lock +++ b/glide.lock @@ -1,4 +1,4 @@ -hash: 2a604d8b74e8659df7db72d063432fa0822fdee6c81bc6657efa3c1bf0d9cd8a +hash: fec4fec4363272870c49e10cea64cc51095ecd0987b9c020c9714d950cf38784 updated: 2017-11-17T14:21:55.148450413+01:00 imports: - name: cloud.google.com/go @@ -516,7 +516,7 @@ imports: - name: github.com/VividCortex/gohistogram version: 51564d9861991fb0ad0f531c99ef602d0f9866e6 - name: github.com/vulcand/oxy - version: 7e9763c4dc71b9758379da3581e6495c145caaab + version: 7b6e758ab449705195df638765c4ca472248908a repo: https://github.com/containous/oxy.git vcs: git subpackages: diff --git a/glide.yaml b/glide.yaml index f0ecc335f..7e6d3364f 100644 --- a/glide.yaml +++ b/glide.yaml @@ -12,7 +12,7 @@ import: - package: github.com/cenk/backoff - package: github.com/containous/flaeg - package: github.com/vulcand/oxy - version: 7e9763c4dc71b9758379da3581e6495c145caaab + version: 7b6e758ab449705195df638765c4ca472248908a repo: https://github.com/containous/oxy.git vcs: git subpackages: diff --git a/integration/grpc_test.go b/integration/grpc_test.go index 6b9aba63e..1deb611f1 100644 --- a/integration/grpc_test.go +++ b/integration/grpc_test.go @@ -224,10 +224,12 @@ func (s *GRPCSuite) TestGRPCBuffer(c *check.C) { var client helloworld.Greeter_StreamExampleClient client, closer, err := callStreamExampleClientGRPC() defer closer() + c.Assert(err, check.IsNil) received := make(chan bool) go func() { - tr, _ := client.Recv() + tr, err := client.Recv() + c.Assert(err, check.IsNil) c.Assert(len(tr.Data), check.Equals, 512) received <- true }() diff --git a/integration/websocket_test.go b/integration/websocket_test.go index ad58c1ead..6cc9c8f47 100644 --- a/integration/websocket_test.go +++ b/integration/websocket_test.go @@ -395,7 +395,7 @@ func (s *WebsocketSuite) TestURLWithURLEncodedChar(c *check.C) { var upgrader = gorillawebsocket.Upgrader{} // use default options srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c.Assert(r.URL.Path, check.Equals, "/ws/http%3A%2F%2Ftest") + c.Assert(r.URL.EscapedPath(), check.Equals, "/ws/http%3A%2F%2Ftest") conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return diff --git a/server/adapters.go b/server/adapters.go index de66c38ba..453b8187c 100644 --- a/server/adapters.go +++ b/server/adapters.go @@ -2,29 +2,8 @@ package server import ( "net/http" - - "github.com/containous/traefik/log" ) -// OxyLogger implements oxy Logger interface with logrus. -type OxyLogger struct { -} - -// Infof logs specified string as Debug level in logrus. -func (oxylogger *OxyLogger) Infof(format string, args ...interface{}) { - log.Debugf(format, args...) -} - -// Warningf logs specified string as Warning level in logrus. -func (oxylogger *OxyLogger) Warningf(format string, args ...interface{}) { - log.Warningf(format, args...) -} - -// Errorf logs specified string as Warningf level in logrus. -func (oxylogger *OxyLogger) Errorf(format string, args ...interface{}) { - log.Warningf(format, args...) -} - func notFoundHandler(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } diff --git a/server/server.go b/server/server.go index 0839e14b7..5678ac5a5 100644 --- a/server/server.go +++ b/server/server.go @@ -39,7 +39,6 @@ import ( "github.com/eapache/channels" thoas_stats "github.com/thoas/stats" "github.com/urfave/negroni" - "github.com/vulcand/oxy/cbreaker" "github.com/vulcand/oxy/connlimit" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/ratelimit" @@ -48,10 +47,6 @@ import ( "golang.org/x/net/http2" ) -var ( - oxyLogger = &OxyLogger{} -) - // Server is the reverse-proxy/load-balancer engine type Server struct { serverEntryPoints serverEntryPoints @@ -959,7 +954,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf } fwd, err := forward.New( - forward.Logger(oxyLogger), + forward.Stream(true), forward.PassHostHeader(frontend.PassHostHeader), forward.RoundTripper(roundTripper), forward.ErrorHandler(errorHandler), @@ -1006,10 +1001,10 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf switch lbMethod { case types.Drr: log.Debugf("Creating load-balancer drr") - rebalancer, _ := roundrobin.NewRebalancer(rr, roundrobin.RebalancerLogger(oxyLogger)) + rebalancer, _ := roundrobin.NewRebalancer(rr) if sticky != nil { log.Debugf("Sticky session with cookie %v", cookieName) - rebalancer, _ = roundrobin.NewRebalancer(rr, roundrobin.RebalancerLogger(oxyLogger), roundrobin.RebalancerStickySession(sticky)) + rebalancer, _ = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(sticky)) } lb = rebalancer if err := configureLBServers(rebalancer, config, frontend); err != nil { @@ -1080,7 +1075,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf continue frontend } log.Debugf("Creating load-balancer connlimit") - lb, err = connlimit.New(lb, extractFunc, maxConns.Amount, connlimit.Logger(oxyLogger)) + lb, err = connlimit.New(lb, extractFunc, maxConns.Amount) if err != nil { log.Errorf("Error creating connlimit: %v", err) log.Errorf("Skipping frontend %s...", frontendName) @@ -1151,7 +1146,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf if config.Backends[frontend.Backend].CircuitBreaker != nil { log.Debugf("Creating circuit breaker %s", config.Backends[frontend.Backend].CircuitBreaker.Expression) - circuitBreaker, err := middlewares.NewCircuitBreaker(lb, config.Backends[frontend.Backend].CircuitBreaker.Expression, cbreaker.Logger(oxyLogger)) + circuitBreaker, err := middlewares.NewCircuitBreaker(lb, config.Backends[frontend.Backend].CircuitBreaker.Expression) if err != nil { log.Errorf("Error creating circuit breaker: %v", err) log.Errorf("Skipping frontend %s...", frontendName) @@ -1445,7 +1440,7 @@ func (server *Server) buildRateLimiter(handler http.Handler, rlConfig *types.Rat return nil, err } } - return ratelimit.New(handler, extractFunc, rateSet, ratelimit.Logger(oxyLogger)) + return ratelimit.New(handler, extractFunc, rateSet) } func (server *Server) buildRetryMiddleware(handler http.Handler, globalConfig configuration.GlobalConfiguration, countServers int, backendName string) http.Handler { diff --git a/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go b/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go index 36ac7abbe..de588acea 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/cbreaker.go @@ -1,4 +1,4 @@ -// package cbreaker implements circuit breaker similar to https://github.com/Netflix/Hystrix/wiki/How-it-Works +// Package cbreaker implements circuit breaker similar to https://github.com/Netflix/Hystrix/wiki/How-it-Works // // Vulcan circuit breaker watches the error condtion to match // after which it activates the fallback scenario, e.g. returns the response code @@ -31,6 +31,8 @@ import ( "sync" "time" + log "github.com/Sirupsen/logrus" + "github.com/mailgun/timetools" "github.com/vulcand/oxy/memmetrics" "github.com/vulcand/oxy/utils" @@ -60,7 +62,6 @@ type CircuitBreaker struct { fallback http.Handler next http.Handler - log utils.Logger clock timetools.TimeProvider } @@ -75,7 +76,6 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption) fallbackDuration: defaultFallbackDuration, recoveryDuration: defaultRecoveryDuration, fallback: defaultFallback, - log: utils.NullLogger, } for _, s := range options { @@ -100,6 +100,11 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption) } func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if log.GetLevel() >= log.DebugLevel { + logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/circuitbreaker: competed ServeHttp on request") + } if c.activateFallback(w, req) { c.fallback.ServeHTTP(w, req) return @@ -121,7 +126,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque c.m.Lock() defer c.m.Unlock() - c.log.Infof("%v is in error state", c) + log.Infof("%v is in error state", c) switch c.state { case stateStandby: @@ -186,13 +191,13 @@ func (c *CircuitBreaker) exec(s SideEffect) { } go func() { if err := s.Exec(); err != nil { - c.log.Errorf("%v side effect failure: %v", c, err) + log.Errorf("%v side effect failure: %v", c, err) } }() } func (c *CircuitBreaker) setState(new cbState, until time.Time) { - c.log.Infof("%v setting state to %v, until %v", c, new, until) + log.Infof("%v setting state to %v, until %v", c, new, until) c.state = new c.until = until switch new { @@ -225,7 +230,7 @@ func (c *CircuitBreaker) checkAndSet() { c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod) if c.state == stateTripped { - c.log.Infof("%v skip set tripped", c) + log.Infof("%v skip set tripped", c) return } @@ -309,14 +314,6 @@ func Fallback(h http.Handler) CircuitBreakerOption { } } -// Logger adds logging for the CircuitBreaker. -func Logger(l utils.Logger) CircuitBreakerOption { - return func(c *CircuitBreaker) error { - c.log = l - return nil - } -} - // cbState is the state of the circuit breaker type cbState int diff --git a/vendor/github.com/vulcand/oxy/cbreaker/effect.go b/vendor/github.com/vulcand/oxy/cbreaker/effect.go index 35d54c518..1d1dedd9d 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/effect.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/effect.go @@ -9,6 +9,7 @@ import ( "net/url" "strings" + log "github.com/Sirupsen/logrus" "github.com/vulcand/oxy/utils" ) @@ -68,9 +69,10 @@ func (w *WebhookSideEffect) Exec() error { if re.Body != nil { defer re.Body.Close() } - _, err = ioutil.ReadAll(re.Body) + body, err := ioutil.ReadAll(re.Body) if err != nil { return err } + log.Infof("%v got response: (%s): %s", w, re.Status, string(body)) return nil } diff --git a/vendor/github.com/vulcand/oxy/cbreaker/fallback.go b/vendor/github.com/vulcand/oxy/cbreaker/fallback.go index 9bd6678c1..d7bd8fb0c 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/fallback.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/fallback.go @@ -5,6 +5,9 @@ import ( "net/http" "net/url" "strconv" + + log "github.com/Sirupsen/logrus" + "github.com/vulcand/oxy/utils" ) type Response struct { @@ -25,20 +28,31 @@ func NewResponseFallback(r Response) (*ResponseFallback, error) { } func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if log.GetLevel() >= log.DebugLevel { + logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/fallback/response: competed ServeHttp on request") + } + if f.r.ContentType != "" { w.Header().Set("Content-Type", f.r.ContentType) } w.Header().Set("Content-Length", strconv.Itoa(len(f.r.Body))) w.WriteHeader(f.r.StatusCode) - w.Write(f.r.Body) + _, err := w.Write(f.r.Body) + if err != nil { + log.Errorf("vulcand/oxy/fallback/response: failed to write response, err: %v", err) + } } type Redirect struct { - URL string + URL string + PreservePath bool } type RedirectFallback struct { u *url.URL + r Redirect } func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { @@ -46,11 +60,25 @@ func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { if err != nil { return nil, err } - return &RedirectFallback{u: u}, nil + return &RedirectFallback{u: u, r: r}, nil } func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Location", f.u.String()) + if log.GetLevel() >= log.DebugLevel { + logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/fallback/redirect: competed ServeHttp on request") + } + + location := f.u.String() + if f.r.PreservePath { + location += req.URL.Path + } + + w.Header().Set("Location", location) w.WriteHeader(http.StatusFound) - w.Write([]byte(http.StatusText(http.StatusFound))) + _, err := w.Write([]byte(http.StatusText(http.StatusFound))) + if err != nil { + log.Errorf("vulcand/oxy/fallback/redirect: failed to write response, err: %v", err) + } } diff --git a/vendor/github.com/vulcand/oxy/cbreaker/predicates.go b/vendor/github.com/vulcand/oxy/cbreaker/predicates.go index a858daf8c..f63875f01 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/predicates.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/predicates.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + log "github.com/Sirupsen/logrus" "github.com/vulcand/predicate" ) @@ -49,7 +50,7 @@ func latencyAtQuantile(quantile float64) toInt { return func(c *CircuitBreaker) int { h, err := c.metrics.LatencyHistogram() if err != nil { - c.log.Errorf("Failed to get latency histogram, for %v error: %v", c, err) + log.Errorf("Failed to get latency histogram, for %v error: %v", c, err) return 0 } return int(h.LatencyAtQuantile(quantile) / time.Millisecond) diff --git a/vendor/github.com/vulcand/oxy/cbreaker/ratio.go b/vendor/github.com/vulcand/oxy/cbreaker/ratio.go index 9758f7442..62db2ea82 100644 --- a/vendor/github.com/vulcand/oxy/cbreaker/ratio.go +++ b/vendor/github.com/vulcand/oxy/cbreaker/ratio.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + log "github.com/Sirupsen/logrus" "github.com/mailgun/timetools" ) @@ -33,14 +34,17 @@ func (r *ratioController) String() string { } func (r *ratioController) allowRequest() bool { + log.Infof("%v", r) t := r.targetRatio() // This condition answers the question - would we satisfy the target ratio if we allow this request? e := r.computeRatio(r.allowed+1, r.denied) if e < t { r.allowed++ + log.Infof("%v allowed", r) return true } r.denied++ + log.Infof("%v denied", r) return false } diff --git a/vendor/github.com/vulcand/oxy/connlimit/connlimit.go b/vendor/github.com/vulcand/oxy/connlimit/connlimit.go index c819585aa..9169321d4 100644 --- a/vendor/github.com/vulcand/oxy/connlimit/connlimit.go +++ b/vendor/github.com/vulcand/oxy/connlimit/connlimit.go @@ -6,6 +6,7 @@ import ( "net/http" "sync" + log "github.com/Sirupsen/logrus" "github.com/vulcand/oxy/utils" ) @@ -20,7 +21,6 @@ type ConnLimiter struct { next http.Handler errHandler utils.ErrorHandler - log utils.Logger } func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) { @@ -40,9 +40,6 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, return nil, err } } - if cl.log == nil { - cl.log = utils.NullLogger - } if cl.errHandler == nil { cl.errHandler = defaultErrHandler } @@ -56,12 +53,12 @@ func (cl *ConnLimiter) Wrap(h http.Handler) { func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { token, amount, err := cl.extract.Extract(r) if err != nil { - cl.log.Errorf("failed to extract source of the connection: %v", err) + log.Errorf("failed to extract source of the connection: %v", err) cl.errHandler.ServeHTTP(w, r, err) return } if err := cl.acquire(token, amount); err != nil { - cl.log.Infof("limiting request source %s: %v", token, err) + log.Infof("limiting request source %s: %v", token, err) cl.errHandler.ServeHTTP(w, r, err) return } @@ -81,7 +78,7 @@ func (cl *ConnLimiter) acquire(token string, amount int64) error { } cl.connections[token] += amount - cl.totalConnections += int64(amount) + cl.totalConnections += amount return nil } @@ -90,7 +87,7 @@ func (cl *ConnLimiter) release(token string, amount int64) { defer cl.mutex.Unlock() cl.connections[token] -= amount - cl.totalConnections -= int64(amount) + cl.totalConnections -= amount // Otherwise it would grow forever if cl.connections[token] == 0 { @@ -110,6 +107,12 @@ type ConnErrHandler struct { } func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { + if log.GetLevel() >= log.DebugLevel { + logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/connlimit: competed ServeHttp on request") + } + if _, ok := err.(*MaxConnError); ok { w.WriteHeader(429) w.Write([]byte(err.Error())) @@ -120,14 +123,6 @@ func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err type ConnLimitOption func(l *ConnLimiter) error -// Logger sets the logger that will be used by this middleware. -func Logger(l utils.Logger) ConnLimitOption { - return func(cl *ConnLimiter) error { - cl.log = l - return nil - } -} - // ErrorHandler sets error handler of the server func ErrorHandler(h utils.ErrorHandler) ConnLimitOption { return func(cl *ConnLimiter) error { diff --git a/vendor/github.com/vulcand/oxy/forward/fwd.go b/vendor/github.com/vulcand/oxy/forward/fwd.go index e6aba940a..ffd226860 100644 --- a/vendor/github.com/vulcand/oxy/forward/fwd.go +++ b/vendor/github.com/vulcand/oxy/forward/fwd.go @@ -7,13 +7,15 @@ import ( "crypto/tls" "io" "net/http" + "net/http/httptest" + "net/http/httputil" "net/url" "os" "reflect" - "strconv" "strings" "time" + log "github.com/Sirupsen/logrus" "github.com/gorilla/websocket" "github.com/vulcand/oxy/utils" ) @@ -29,15 +31,7 @@ type optSetter func(f *Forwarder) error // be delegated func PassHostHeader(b bool) optSetter { return func(f *Forwarder) error { - f.passHost = b - return nil - } -} - -// StreamResponse forces streaming body (flushes response directly to client) -func StreamResponse(b bool) optSetter { - return func(f *Forwarder) error { - f.httpForwarder.streamResponse = b + f.httpForwarder.passHost = b return nil } } @@ -46,7 +40,7 @@ func StreamResponse(b bool) optSetter { // Forwarder will use http.DefaultTransport as a default round tripper func RoundTripper(r http.RoundTripper) optSetter { return func(f *Forwarder) error { - f.roundTripper = r + f.httpForwarder.roundTripper = r return nil } } @@ -59,10 +53,11 @@ func Rewriter(r ReqRewriter) optSetter { } } -// WebsocketRewriter defines a request rewriter for the websocket forwarder -func WebsocketRewriter(r ReqRewriter) optSetter { +// PassHostHeader specifies if a client's Host header field should +// be delegated +func WebsocketTLSClientConfig(tcc *tls.Config) optSetter { return func(f *Forwarder) error { - f.websocketForwarder.rewriter = r + f.httpForwarder.tlsClientConfig = tcc return nil } } @@ -75,27 +70,74 @@ func ErrorHandler(h utils.ErrorHandler) optSetter { } } -// Logger specifies the logger to use. -// Forwarder will default to oxyutils.NullLogger if no logger has been specified -func Logger(l utils.Logger) optSetter { +// Stream specifies if HTTP responses should be streamed. +func Stream(stream bool) optSetter { + return func(f *Forwarder) error { + f.stream = stream + return nil + } +} + +// Logger defines the logger the forwarder will use. +// +// It defaults to logrus.StandardLogger(), the global logger used by logrus. +func Logger(l *log.Logger) optSetter { return func(f *Forwarder) error { f.log = l return nil } } +func StateListener(stateListener UrlForwardingStateListener) optSetter { + return func(f *Forwarder) error { + f.stateListener = stateListener + return nil + } +} + +func ResponseModifier(responseModifier func(*http.Response) error) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.modifyResponse = responseModifier + return nil + } +} + +func StreamingFlushInterval(flushInterval time.Duration) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.flushInterval = flushInterval + return nil + } +} + +type ErrorHandlingRoundTripper struct { + http.RoundTripper + errorHandler utils.ErrorHandler +} + +func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + res, err := rt.RoundTripper.RoundTrip(req) + if err != nil { + // We use the recorder from httptest because there isn't another `public` implementation of a recorder. + recorder := httptest.NewRecorder() + rt.errorHandler.ServeHTTP(recorder, req, err) + res = recorder.Result() + err = nil + } + return res, err +} + // Forwarder wraps two traffic forwarding implementations: HTTP and websockets. // It decides based on the specified request which implementation to use type Forwarder struct { *httpForwarder - *websocketForwarder *handlerContext + stateListener UrlForwardingStateListener + stream bool } // handlerContext defines a handler context for error reporting and logging type handlerContext struct { errHandler utils.ErrorHandler - log utils.Logger } // httpForwarder is a handler that can reverse proxy @@ -104,32 +146,40 @@ type httpForwarder struct { roundTripper http.RoundTripper rewriter ReqRewriter passHost bool - streamResponse bool + flushInterval time.Duration + modifyResponse func(*http.Response) error + + tlsClientConfig *tls.Config + + log *log.Logger } -// websocketForwarder is a handler that can reverse proxy -// websocket traffic -type websocketForwarder struct { - rewriter ReqRewriter - TLSClientConfig *tls.Config -} +const ( + defaultFlushInterval = time.Duration(100) * time.Millisecond + StateConnected = iota + StateDisconnected +) + +type UrlForwardingStateListener func(*url.URL, int) // New creates an instance of Forwarder based on the provided list of configuration options func New(setters ...optSetter) (*Forwarder, error) { f := &Forwarder{ - httpForwarder: &httpForwarder{}, - websocketForwarder: &websocketForwarder{}, - handlerContext: &handlerContext{}, + httpForwarder: &httpForwarder{log: log.StandardLogger()}, + handlerContext: &handlerContext{}, } for _, s := range setters { if err := s(f); err != nil { return nil, err } } - if f.httpForwarder.roundTripper == nil { - f.httpForwarder.roundTripper = http.DefaultTransport + + if !f.stream { + f.flushInterval = 0 + } else if f.flushInterval == 0 { + f.flushInterval = defaultFlushInterval } - f.websocketForwarder.TLSClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig + if f.httpForwarder.rewriter == nil { h, err := os.Hostname() if err != nil { @@ -137,136 +187,104 @@ func New(setters ...optSetter) (*Forwarder, error) { } f.httpForwarder.rewriter = &HeaderRewriter{TrustForwardHeader: true, Hostname: h} } - if f.log == nil { - f.log = utils.NullLogger + + if f.httpForwarder.roundTripper == nil { + f.httpForwarder.roundTripper = http.DefaultTransport } + if f.errHandler == nil { f.errHandler = utils.DefaultHandler } + + if f.tlsClientConfig == nil { + f.tlsClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig + } + + f.httpForwarder.roundTripper = ErrorHandlingRoundTripper{ + RoundTripper: f.httpForwarder.roundTripper, + errorHandler: f.errHandler, + } + return f, nil } // ServeHTTP decides which forwarder to use based on the specified // request and delegates to the proper implementation func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if isWebsocketRequest(req) { - f.websocketForwarder.serveHTTP(w, req, f.handlerContext) + if f.log.Level >= log.DebugLevel { + logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/forward: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/forward: competed ServeHttp on request") + } + + if f.stateListener != nil { + f.stateListener(req.URL, StateConnected) + defer f.stateListener(req.URL, StateDisconnected) + } + if IsWebsocketRequest(req) { + f.httpForwarder.serveWebSocket(w, req, f.handlerContext) } else { f.httpForwarder.serveHTTP(w, req, f.handlerContext) } } -// serveHTTP forwards HTTP traffic using the configured transport -func (f *httpForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, ctx *handlerContext) { - start := time.Now().UTC() - - response, err := f.roundTripper.RoundTrip(f.copyRequest(req, req.URL)) - - if err != nil { - ctx.log.Errorf("Error forwarding to %v, err: %v", req.URL, err) - ctx.errHandler.ServeHTTP(w, req, err) - return - } - - utils.CopyHeaders(w.Header(), response.Header) - // Remove hop-by-hop headers. - utils.RemoveHeaders(w.Header(), HopHeaders...) - - announcedTrailerKeyCount := len(response.Trailer) - if announcedTrailerKeyCount > 0 { - trailerKeys := make([]string, 0, announcedTrailerKeyCount) - for k := range response.Trailer { - trailerKeys = append(trailerKeys, k) - } - w.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) - } - - w.WriteHeader(response.StatusCode) - - stream := f.streamResponse - if !stream { - contentType, err := utils.GetHeaderMediaType(response.Header, ContentType) +func (f *httpForwarder) getUrlFromRequest(req *http.Request) *url.URL { + // If the Request was created by Go via a real HTTP request, RequestURI will + // contain the original query string. If the Request was created in code, RequestURI + // will be empty, and we will use the URL object instead + u := req.URL + if req.RequestURI != "" { + parsedURL, err := url.ParseRequestURI(req.RequestURI) if err == nil { - stream = contentType == "text/event-stream" + u = parsedURL + } else { + f.log.Warnf("vulcand/oxy/forward: error when parsing RequestURI: %s", err) } } - - flush := stream || req.ProtoMajor == 2 - written, err := io.Copy(newResponseFlusher(w, flush), response.Body) - if err != nil { - ctx.log.Errorf("Error copying upstream response body: %v", err) - ctx.errHandler.ServeHTTP(w, req, err) - return - } - - defer response.Body.Close() - - forceSetTrailers := len(response.Trailer) != announcedTrailerKeyCount - shallowCopyTrailers(w.Header(), response.Trailer, forceSetTrailers) - - if written != 0 { - w.Header().Set(ContentLength, strconv.FormatInt(written, 10)) - } - - if req.TLS != nil { - ctx.log.Infof("Round trip: %v, code: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v", - req.URL, response.StatusCode, time.Now().UTC().Sub(start), - req.TLS.Version, - req.TLS.DidResume, - req.TLS.CipherSuite, - req.TLS.ServerName) - } else { - ctx.log.Infof("Round trip: %v, code: %v, duration: %v", - req.URL, response.StatusCode, time.Now().UTC().Sub(start)) - } - + return u } -// copyRequest makes a copy of the specified request to be sent using the configured -// transport -func (f *httpForwarder) copyRequest(req *http.Request, u *url.URL) *http.Request { - outReq := new(http.Request) - *outReq = *req // includes shallow copies of maps, but we handle this below +// Modify the request to handle the target URL +func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) { + outReq.URL = utils.CopyURL(outReq.URL) + outReq.URL.Scheme = target.Scheme + outReq.URL.Host = target.Host + + u := f.getUrlFromRequest(outReq) + + outReq.URL.Path = u.Path + outReq.URL.RawPath = u.RawPath + outReq.URL.RawQuery = u.RawQuery + outReq.RequestURI = "" // Outgoing request should not have RequestURI - outReq.URL = utils.CopyURL(req.URL) - outReq.URL.Scheme = u.Scheme - outReq.URL.Host = u.Host - outReq.URL.Opaque = req.RequestURI - // raw query is already included in RequestURI, so ignore it to avoid dupes - outReq.URL.RawQuery = "" // Do not pass client Host header unless optsetter PassHostHeader is set. if !f.passHost { - outReq.Host = u.Host + outReq.Host = target.Host } + outReq.Proto = "HTTP/1.1" outReq.ProtoMajor = 1 outReq.ProtoMinor = 1 - // Overwrite close flag so we can keep persistent connection for the backend servers - outReq.Close = false - - outReq.Header = make(http.Header) - utils.CopyHeaders(outReq.Header, req.Header) - if f.rewriter != nil { f.rewriter.Rewrite(outReq) } - - if req.ContentLength == 0 { - // https://github.com/golang/go/issues/16036: nil Body for http.Transport retries - outReq.Body = nil - } - - return outReq } // serveHTTP forwards websocket traffic -func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, ctx *handlerContext) { - outReq := f.copyRequest(req, req.URL) +func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, ctx *handlerContext) { + if f.log.Level >= log.DebugLevel { + logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/forward/websocket: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/forward/websocket: competed ServeHttp on request") + } + + outReq := f.copyWebSocketRequest(req) dialer := websocket.DefaultDialer - if outReq.URL.Scheme == "wss" && f.TLSClientConfig != nil { - dialer.TLSClientConfig = f.TLSClientConfig.Clone() + if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil { + dialer.TLSClientConfig = f.tlsClientConfig.Clone() + // WebSocket is only in http/1.1 dialer.TLSClientConfig.NextProtos = []string{"http/1.1"} } targetConn, resp, err := dialer.Dial(outReq.URL.String(), outReq.Header) @@ -274,33 +292,33 @@ func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, if resp == nil { ctx.errHandler.ServeHTTP(w, req, err) } else { - ctx.log.Errorf("Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status) + log.Errorf("vulcand/oxy/forward/websocket: Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status) hijacker, ok := w.(http.Hijacker) if !ok { - ctx.log.Errorf("%s can not be hijack", reflect.TypeOf(w)) + log.Errorf("vulcand/oxy/forward/websocket: %s can not be hijack", reflect.TypeOf(w)) ctx.errHandler.ServeHTTP(w, req, err) return } - conn, _, err := hijacker.Hijack() - if err != nil { - ctx.log.Errorf("Failed to hijack responseWriter") - ctx.errHandler.ServeHTTP(w, req, err) + conn, _, errHijack := hijacker.Hijack() + if errHijack != nil { + log.Errorf("vulcand/oxy/forward/websocket: Failed to hijack responseWriter") + ctx.errHandler.ServeHTTP(w, req, errHijack) return } defer conn.Close() - err = resp.Write(conn) - if err != nil { - ctx.log.Errorf("Failed to forward response") - ctx.errHandler.ServeHTTP(w, req, err) + errWrite := resp.Write(conn) + if errWrite != nil { + log.Errorf("vulcand/oxy/forward/websocket: Failed to forward response") + ctx.errHandler.ServeHTTP(w, req, errWrite) return } } return } - //Only the targetConn choose to CheckOrigin or not + // Only the targetConn choose to CheckOrigin or not upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} @@ -308,62 +326,64 @@ func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, utils.RemoveHeaders(resp.Header, WebsocketUpgradeHeaders...) underlyingConn, err := upgrader.Upgrade(w, req, resp.Header) if err != nil { - ctx.log.Errorf("Error while upgrading connection : %v", err) + log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) return } defer underlyingConn.Close() defer targetConn.Close() errc := make(chan error, 2) - replicate := func(dst io.Writer, src io.Reader) { - _, err := io.Copy(dst, src) - errc <- err + replicate := func(dst io.Writer, src io.Reader, dstName string, srcName string) { + _, errCopy := io.Copy(dst, src) + if errCopy != nil { + f.log.Errorf("vulcand/oxy/forward/websocket: Error when copying from %s to %s using io.Copy: %v", srcName, dstName, errCopy) + } else { + f.log.Infof("vulcand/oxy/forward/websocket: Copying from %s to %s using io.Copy completed without error.", srcName, dstName) + } + errc <- errCopy } - go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn()) + go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn(), "backend", "client") // Try to read the first message - t, msg, err := targetConn.ReadMessage() + msgType, msg, err := targetConn.ReadMessage() if err != nil { - ctx.log.Errorf("Couldn't read first message : %v", err) + log.Errorf("vulcand/oxy/forward/websocket: Couldn't read first message : %v", err) } else { - underlyingConn.WriteMessage(t, msg) + underlyingConn.WriteMessage(msgType, msg) } - go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn()) + go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn(), "client", "backend") <-errc - } -// copyRequest makes a copy of the specified request. -func (f *websocketForwarder) copyRequest(req *http.Request, u *url.URL) (outReq *http.Request) { +// copyWebsocketRequest makes a copy of the specified request. +func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Request) { outReq = new(http.Request) *outReq = *req // includes shallow copies of maps, but we handle this below outReq.URL = utils.CopyURL(req.URL) - outReq.URL.Scheme = u.Scheme + outReq.URL.Scheme = req.URL.Scheme - //sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs. - switch u.Scheme { + // sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs. + switch req.URL.Scheme { case "https": outReq.URL.Scheme = "wss" case "http": outReq.URL.Scheme = "ws" } - if requestURI, err := url.ParseRequestURI(outReq.RequestURI); err == nil { - if requestURI.RawPath != "" { - outReq.URL.Path = requestURI.RawPath - } else { - outReq.URL.Path = requestURI.Path - } - outReq.URL.RawQuery = requestURI.RawQuery - } + u := f.getUrlFromRequest(outReq) - outReq.URL.Host = u.Host + outReq.URL.Path = u.Path + outReq.URL.RawPath = u.RawPath + outReq.URL.RawQuery = u.RawQuery + outReq.RequestURI = "" // Outgoing request should not have RequestURI + + outReq.URL.Host = req.URL.Host outReq.Header = make(http.Header) - //gorilla websocket use this header to set the request.Host tested in checkSameOrigin + // gorilla websocket use this header to set the request.Host tested in checkSameOrigin outReq.Header.Set("Host", outReq.Host) utils.CopyHeaders(outReq.Header, req.Header) utils.RemoveHeaders(outReq.Header, WebsocketDialHeaders...) @@ -374,9 +394,48 @@ func (f *websocketForwarder) copyRequest(req *http.Request, u *url.URL) (outReq return outReq } +// serveHTTP forwards HTTP traffic using the configured transport +func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ctx *handlerContext) { + if f.log.Level >= log.DebugLevel { + logEntry := f.log.WithField("Request", utils.DumpHttpRequest(inReq)) + logEntry.Debug("vulcand/oxy/forward/http: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/forward/http: completed ServeHttp on request") + } + + pw := &utils.ProxyWriter{ + W: w, + } + start := time.Now().UTC() + + outReq := new(http.Request) + *outReq = *inReq // includes shallow copies of maps, but we handle this in Director + + revproxy := httputil.ReverseProxy{ + Director: func(req *http.Request) { + f.modifyRequest(req, inReq.URL) + }, + Transport: f.roundTripper, + FlushInterval: f.flushInterval, + ModifyResponse: f.modifyResponse, + } + revproxy.ServeHTTP(pw, outReq) + + if inReq.TLS != nil { + f.log.Infof("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v", + inReq.URL, pw.Code, pw.Length, time.Now().UTC().Sub(start), + inReq.TLS.Version, + inReq.TLS.DidResume, + inReq.TLS.CipherSuite, + inReq.TLS.ServerName) + } else { + f.log.Infof("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v", + inReq.URL, pw.Code, pw.Length, time.Now().UTC().Sub(start)) + } +} + // isWebsocketRequest determines if the specified HTTP request is a // websocket handshake request -func isWebsocketRequest(req *http.Request) bool { +func IsWebsocketRequest(req *http.Request) bool { containsHeader := func(name, value string) bool { items := strings.Split(req.Header.Get(name), ",") for _, item := range items { @@ -388,12 +447,3 @@ func isWebsocketRequest(req *http.Request) bool { } return containsHeader(Connection, "upgrade") && containsHeader(Upgrade, "websocket") } - -func shallowCopyTrailers(dstHeader, srcTrailer http.Header, forceSetTrailers bool) { - for k, vv := range srcTrailer { - if forceSetTrailers { - k = http.TrailerPrefix + k - } - dstHeader[k] = vv - } -} diff --git a/vendor/github.com/vulcand/oxy/forward/headers.go b/vendor/github.com/vulcand/oxy/forward/headers.go index 629421551..9c9bd734c 100644 --- a/vendor/github.com/vulcand/oxy/forward/headers.go +++ b/vendor/github.com/vulcand/oxy/forward/headers.go @@ -16,7 +16,6 @@ const ( TransferEncoding = "Transfer-Encoding" Upgrade = "Upgrade" ContentLength = "Content-Length" - ContentType = "Content-Type" SecWebsocketKey = "Sec-Websocket-Key" SecWebsocketVersion = "Sec-Websocket-Version" SecWebsocketExtensions = "Sec-Websocket-Extensions" diff --git a/vendor/github.com/vulcand/oxy/forward/responseflusher.go b/vendor/github.com/vulcand/oxy/forward/responseflusher.go deleted file mode 100644 index ea8cc89a7..000000000 --- a/vendor/github.com/vulcand/oxy/forward/responseflusher.go +++ /dev/null @@ -1,53 +0,0 @@ -package forward - -import ( - "bufio" - "fmt" - "net" - "net/http" -) - -var ( - _ http.Hijacker = &responseFlusher{} - _ http.Flusher = &responseFlusher{} - _ http.CloseNotifier = &responseFlusher{} -) - -type responseFlusher struct { - http.ResponseWriter - flush bool -} - -func newResponseFlusher(rw http.ResponseWriter, flush bool) *responseFlusher { - return &responseFlusher{ - ResponseWriter: rw, - flush: flush, - } -} - -func (wf *responseFlusher) Write(p []byte) (int, error) { - written, err := wf.ResponseWriter.Write(p) - if wf.flush { - wf.Flush() - } - return written, err -} - -func (wf *responseFlusher) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hijacker, ok := wf.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface") - } - return hijacker.Hijack() -} - -func (wf *responseFlusher) CloseNotify() <-chan bool { - return wf.ResponseWriter.(http.CloseNotifier).CloseNotify() -} - -func (wf *responseFlusher) Flush() { - flusher, ok := wf.ResponseWriter.(http.Flusher) - if ok { - flusher.Flush() - } -} diff --git a/vendor/github.com/vulcand/oxy/forward/rewrite.go b/vendor/github.com/vulcand/oxy/forward/rewrite.go index 6a39241f2..38a7f7fc4 100644 --- a/vendor/github.com/vulcand/oxy/forward/rewrite.go +++ b/vendor/github.com/vulcand/oxy/forward/rewrite.go @@ -14,16 +14,25 @@ type HeaderRewriter struct { Hostname string } +// clean up IP in case if it is ipv6 address and it has {zone} information in it, like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692" +func ipv6fix(clientIP string) string { + return strings.Split(clientIP, "%")[0] +} + func (rw *HeaderRewriter) Rewrite(req *http.Request) { if !rw.TrustForwardHeader { utils.RemoveHeaders(req.Header, XHeaders...) } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { - if prior, ok := req.Header[XForwardedFor]; ok { - req.Header.Set(XForwardedFor, strings.Join(prior, ", ")+", "+clientIP) - } else { - req.Header.Set(XForwardedFor, clientIP) + clientIP = ipv6fix(clientIP) + // If not websocket, done in http.ReverseProxy + if IsWebsocketRequest(req) { + if prior, ok := req.Header[XForwardedFor]; ok { + req.Header.Set(XForwardedFor, strings.Join(prior, ", ")+", "+clientIP) + } else { + req.Header.Set(XForwardedFor, clientIP) + } } if req.Header.Get(XRealIp) == "" { @@ -40,7 +49,15 @@ func (rw *HeaderRewriter) Rewrite(req *http.Request) { } } - if xfp := req.Header.Get(XForwardedPort); xfp == "" { + if IsWebsocketRequest(req) { + if req.Header.Get(XForwardedProto) == "https" { + req.Header.Set(XForwardedProto, "wss") + } else { + req.Header.Set(XForwardedProto, "ws") + } + } + + if xfPort := req.Header.Get(XForwardedPort); xfPort == "" { req.Header.Set(XForwardedPort, forwardedPort(req)) } @@ -52,9 +69,11 @@ func (rw *HeaderRewriter) Rewrite(req *http.Request) { req.Header.Set(XForwardedServer, rw.Hostname) } - // Remove hop-by-hop headers to the backend. Especially important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. - utils.RemoveHeaders(req.Header, HopHeaders...) + if !IsWebsocketRequest(req) { + // Remove hop-by-hop headers to the backend. Especially important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. + utils.RemoveHeaders(req.Header, HopHeaders...) + } } func forwardedPort(req *http.Request) string { diff --git a/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go b/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go index 5fa068911..5aeb13ae3 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/anomaly.go @@ -40,7 +40,7 @@ func SplitRatios(values []float64) (good map[float64]bool, bad map[float64]bool) } // SplitFloat64 provides simple anomaly detection for skewed data sets with no particular distribution. -// In essense it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly +// In essence it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly // There's a corner case where there are just 2 values, so by definition there's no value that exceeds the threshold. // This case is solved by introducing additional value that we know is good, e.g. 0. That helps to improve the detection results // on such data sets. diff --git a/vendor/github.com/vulcand/oxy/memmetrics/counter.go b/vendor/github.com/vulcand/oxy/memmetrics/counter.go index 71e0a7951..361d8a878 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/counter.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/counter.go @@ -71,9 +71,7 @@ func (c *RollingCounter) Clone() *RollingCounter { lastBucket: c.lastBucket, lastUpdated: c.lastUpdated, } - for i, v := range c.values { - other.values[i] = v - } + copy(other.values, c.values) return other } diff --git a/vendor/github.com/vulcand/oxy/memmetrics/histogram.go b/vendor/github.com/vulcand/oxy/memmetrics/histogram.go index 21db94ce1..02c1d561e 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/histogram.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/histogram.go @@ -34,6 +34,15 @@ func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) }, nil } +func (r *HDRHistogram) Export() *HDRHistogram { + var hist *hdrhistogram.Histogram = nil + if r.h != nil { + snapshot := r.h.Export() + hist = hdrhistogram.Import(snapshot) + } + return &HDRHistogram{low: r.low, high: r.high, sigfigs: r.sigfigs, h: hist} +} + // Returns latency at quantile with microsecond precision func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond @@ -118,6 +127,26 @@ func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, return rh, nil } +func (r *RollingHDRHistogram) Export() *RollingHDRHistogram { + export := &RollingHDRHistogram{} + export.idx = r.idx + export.lastRoll = r.lastRoll + export.period = r.period + export.bucketCount = r.bucketCount + export.low = r.low + export.high = r.high + export.sigfigs = r.sigfigs + export.clock = r.clock + + exportBuckets := make([]*HDRHistogram, len(r.buckets)) + for i, hist := range r.buckets { + exportBuckets[i] = hist.Export() + } + export.buckets = exportBuckets + + return export +} + func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs { return fmt.Errorf("can't merge") @@ -150,8 +179,8 @@ func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) { return m, err } for _, h := range r.buckets { - if m.Merge(h); err != nil { - return nil, err + if errMerge := m.Merge(h); errMerge != nil { + return nil, errMerge } } return m, nil diff --git a/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go b/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go index 0b1acf49f..4bdb4bba2 100644 --- a/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go +++ b/vendor/github.com/vulcand/oxy/memmetrics/roundtrip.go @@ -20,6 +20,7 @@ type RTMetrics struct { statusCodes map[int]*RollingCounter statusCodesLock sync.RWMutex histogram *RollingHDRHistogram + histogramLock sync.RWMutex newCounter NewCounterFn newHist NewRollingHistogramFn @@ -102,12 +103,39 @@ func NewRTMetrics(settings ...rrOptSetter) (*RTMetrics, error) { return m, nil } +// Returns a new RTMetrics which is a copy of the current one +func (m *RTMetrics) Export() *RTMetrics { + m.statusCodesLock.RLock() + defer m.statusCodesLock.RUnlock() + m.histogramLock.RLock() + defer m.histogramLock.RUnlock() + + export := &RTMetrics{} + export.statusCodesLock = sync.RWMutex{} + export.histogramLock = sync.RWMutex{} + export.total = m.total.Clone() + export.netErrors = m.netErrors.Clone() + exportStatusCodes := map[int]*RollingCounter{} + for code, rollingCounter := range m.statusCodes { + exportStatusCodes[code] = rollingCounter.Clone() + } + export.statusCodes = exportStatusCodes + if m.histogram != nil { + export.histogram = m.histogram.Export() + } + export.newCounter = m.newCounter + export.newHist = m.newHist + export.clock = m.clock + + return export +} + func (m *RTMetrics) CounterWindowSize() time.Duration { return m.total.WindowSize() } // GetNetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection -// that occured in the given time window compared to the total requests count. +// that occurred in the given time window compared to the total requests count. func (m *RTMetrics) NetworkErrorRatio() float64 { if m.total.Count() == 0 { return 0 @@ -148,11 +176,13 @@ func (m *RTMetrics) Append(other *RTMetrics) error { return err } + copied := other.Export() + m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() - other.statusCodesLock.RLock() - defer other.statusCodesLock.RUnlock() - for code, c := range other.statusCodes { + m.histogramLock.Lock() + defer m.histogramLock.Unlock() + for code, c := range copied.statusCodes { o, ok := m.statusCodes[code] if ok { if err := o.Append(c); err != nil { @@ -163,7 +193,7 @@ func (m *RTMetrics) Append(other *RTMetrics) error { } } - return m.histogram.Append(other.histogram) + return m.histogram.Append(copied.histogram) } func (m *RTMetrics) Record(code int, duration time.Duration) { @@ -200,35 +230,36 @@ func (m *RTMetrics) StatusCodesCounts() map[int]int64 { // GetLatencyHistogram computes and returns resulting histogram with latencies observed. func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) { + m.histogramLock.Lock() + defer m.histogramLock.Unlock() return m.histogram.Merged() } func (m *RTMetrics) Reset() { + m.statusCodesLock.Lock() + defer m.statusCodesLock.Unlock() + m.histogramLock.Lock() + defer m.histogramLock.Unlock() m.histogram.Reset() m.total.Reset() m.netErrors.Reset() - m.statusCodesLock.Lock() - defer m.statusCodesLock.Unlock() m.statusCodes = make(map[int]*RollingCounter) } -func (m *RTMetrics) recordNetError() error { - m.netErrors.Inc(1) - return nil -} - func (m *RTMetrics) recordLatency(d time.Duration) error { + m.histogramLock.Lock() + defer m.histogramLock.Unlock() return m.histogram.RecordLatencies(d, 1) } func (m *RTMetrics) recordStatusCode(statusCode int) error { - m.statusCodesLock.RLock() + m.statusCodesLock.Lock() if c, ok := m.statusCodes[statusCode]; ok { c.Inc(1) - m.statusCodesLock.RUnlock() + m.statusCodesLock.Unlock() return nil } - m.statusCodesLock.RUnlock() + m.statusCodesLock.Unlock() m.statusCodesLock.Lock() defer m.statusCodesLock.Unlock() diff --git a/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go b/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go index c621a6070..5edf1349d 100644 --- a/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go +++ b/vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go @@ -7,6 +7,7 @@ import ( "sync" "time" + log "github.com/Sirupsen/logrus" "github.com/mailgun/timetools" "github.com/mailgun/ttlmap" "github.com/vulcand/oxy/utils" @@ -65,7 +66,6 @@ type TokenLimiter struct { mutex sync.Mutex bucketSets *ttlmap.TtlMap errHandler utils.ErrorHandler - log utils.Logger capacity int next http.Handler } @@ -110,7 +110,7 @@ func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if err := tl.consumeRates(req, source, amount); err != nil { - tl.log.Infof("limiting request %v %v, limit: %v", req.Method, req.URL, err) + log.Infof("limiting request %v %v, limit: %v", req.Method, req.URL, err) tl.errHandler.ServeHTTP(w, req, err) return } @@ -155,7 +155,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet { rates, err := tl.extractRates.Extract(req) if err != nil { - tl.log.Errorf("Failed to retrieve rates: %v", err) + log.Errorf("Failed to retrieve rates: %v", err) return tl.defaultRates } @@ -190,14 +190,6 @@ func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err type TokenLimiterOption func(l *TokenLimiter) error -// Logger sets the logger that will be used by this middleware. -func Logger(l utils.Logger) TokenLimiterOption { - return func(cl *TokenLimiter) error { - cl.log = l - return nil - } -} - // ErrorHandler sets error handler of the server func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption { return func(cl *TokenLimiter) error { @@ -233,9 +225,6 @@ func Capacity(cap int) TokenLimiterOption { var defaultErrHandler = &RateErrHandler{} func setDefaults(tl *TokenLimiter) { - if tl.log == nil { - tl.log = utils.NullLogger - } if tl.capacity <= 0 { tl.capacity = DefaultCapacity } diff --git a/vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go b/vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go new file mode 100644 index 000000000..418f4988c --- /dev/null +++ b/vendor/github.com/vulcand/oxy/roundrobin/RequestRewriteListener.go @@ -0,0 +1,5 @@ +package roundrobin + +import "net/http" + +type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request) diff --git a/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go b/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go index d9a8939ab..0a5c3a12e 100644 --- a/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go +++ b/vendor/github.com/vulcand/oxy/roundrobin/rebalancer.go @@ -7,6 +7,7 @@ import ( "sync" "time" + log "github.com/Sirupsen/logrus" "github.com/mailgun/timetools" "github.com/vulcand/oxy/memmetrics" "github.com/vulcand/oxy/utils" @@ -42,22 +43,15 @@ type Rebalancer struct { // errHandler is HTTP handler called in case of errors errHandler utils.ErrorHandler - log utils.Logger - ratings []float64 // creates new meters newMeter NewMeterFn // sticky session object - ss *StickySession -} + stickySession *StickySession -func RebalancerLogger(log utils.Logger) RebalancerOption { - return func(r *Rebalancer) error { - r.log = log - return nil - } + requestRewriteListener RequestRewriteListener } func RebalancerClock(clock timetools.TimeProvider) RebalancerOption { @@ -89,18 +83,26 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption { } } -func RebalancerStickySession(ss *StickySession) RebalancerOption { +func RebalancerStickySession(stickySession *StickySession) RebalancerOption { return func(r *Rebalancer) error { - r.ss = ss + r.stickySession = stickySession + return nil + } +} + +// RebalancerErrorHandler is a functional argument that sets error handler of the server +func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption { + return func(r *Rebalancer) error { + r.requestRewriteListener = rrl return nil } } func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) { rb := &Rebalancer{ - mtx: &sync.Mutex{}, - next: handler, - ss: nil, + mtx: &sync.Mutex{}, + next: handler, + stickySession: nil, } for _, o := range opts { if err := o(rb); err != nil { @@ -113,9 +115,6 @@ func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalanc if rb.backoffDuration == 0 { rb.backoffDuration = 10 * time.Second } - if rb.log == nil { - rb.log = &utils.NOPLogger{} - } if rb.newMeter == nil { rb.newMeter = func() (Meter, error) { rc, err := memmetrics.NewRatioCounter(10, time.Second, memmetrics.RatioClock(rb.clock)) @@ -143,6 +142,12 @@ func (rb *Rebalancer) Servers() []*url.URL { } func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if log.GetLevel() >= log.DebugLevel { + logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: competed ServeHttp on request") + } + pw := &utils.ProxyWriter{W: w} start := rb.clock.UtcNow() @@ -150,16 +155,15 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { newReq := *req stuck := false - if rb.ss != nil { - cookie_url, present, err := rb.ss.GetBackend(&newReq, rb.Servers()) + if rb.stickySession != nil { + cookieUrl, present, err := rb.stickySession.GetBackend(&newReq, rb.Servers()) if err != nil { - rb.errHandler.ServeHTTP(w, req, err) - return + log.Infof("vulcand/oxy/roundrobin/rebalancer: error using server from cookie: %v", err) } if present { - newReq.URL = cookie_url + newReq.URL = cookieUrl stuck = true } } @@ -171,12 +175,23 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - if rb.ss != nil { - rb.ss.StickBackend(url, &w) + if log.GetLevel() >= log.DebugLevel { + //log which backend URL we're sending this request to + log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL") + } + + if rb.stickySession != nil { + rb.stickySession.StickBackend(url, &w) } newReq.URL = url } + + //Emit event to a listener if one exists + if rb.requestRewriteListener != nil { + rb.requestRewriteListener(req, &newReq) + } + rb.next.Next().ServeHTTP(pw, &newReq) rb.recordMetrics(newReq.URL, pw.Code, rb.clock.UtcNow().Sub(start)) @@ -262,11 +277,11 @@ func (rb *Rebalancer) upsertServer(u *url.URL, weight int) error { return nil } -func (r *Rebalancer) findServer(u *url.URL) (*rbServer, int) { - if len(r.servers) == 0 { +func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) { + if len(rb.servers) == 0 { return nil, -1 } - for i, s := range r.servers { + for i, s := range rb.servers { if sameURL(u, s.url) { return s, i } @@ -304,7 +319,7 @@ func (rb *Rebalancer) adjustWeights() { func (rb *Rebalancer) applyWeights() { for _, srv := range rb.servers { - rb.log.Infof("upsert server %v, weight %v", srv.url, srv.curWeight) + log.Infof("upsert server %v, weight %v", srv.url, srv.curWeight) rb.next.UpsertServer(srv.url, Weight(srv.curWeight)) } } @@ -316,7 +331,7 @@ func (rb *Rebalancer) setMarkedWeights() bool { if srv.good { weight := increase(srv.curWeight) if weight <= FSMMaxWeight { - rb.log.Infof("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight) + log.Infof("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight) srv.curWeight = weight changed = true } @@ -363,13 +378,13 @@ func (rb *Rebalancer) markServers() bool { } } if len(g) != 0 && len(b) != 0 { - rb.log.Infof("bad: %v good: %v, ratings: %v", b, g, rb.ratings) + log.Infof("bad: %v good: %v, ratings: %v", b, g, rb.ratings) } return len(g) != 0 && len(b) != 0 } func (rb *Rebalancer) convergeWeights() bool { - // If we have previoulsy changed servers try to restore weights to the original state + // If we have previously changed servers try to restore weights to the original state changed := false for _, s := range rb.servers { if s.origWeight == s.curWeight { @@ -377,7 +392,7 @@ func (rb *Rebalancer) convergeWeights() bool { } changed = true newWeight := decrease(s.origWeight, s.curWeight) - rb.log.Infof("decreasing weight of %v from %v to %v", s.url, s.curWeight, newWeight) + log.Infof("decreasing weight of %v from %v to %v", s.url, s.curWeight, newWeight) s.curWeight = newWeight } if !changed { diff --git a/vendor/github.com/vulcand/oxy/roundrobin/rr.go b/vendor/github.com/vulcand/oxy/roundrobin/rr.go index 4f1b0a30a..69e06f6ce 100644 --- a/vendor/github.com/vulcand/oxy/roundrobin/rr.go +++ b/vendor/github.com/vulcand/oxy/roundrobin/rr.go @@ -7,6 +7,7 @@ import ( "net/url" "sync" + log "github.com/Sirupsen/logrus" "github.com/vulcand/oxy/utils" ) @@ -29,9 +30,17 @@ func ErrorHandler(h utils.ErrorHandler) LBOption { } } -func EnableStickySession(ss *StickySession) LBOption { +func EnableStickySession(stickySession *StickySession) LBOption { return func(s *RoundRobin) error { - s.ss = ss + s.stickySession = stickySession + return nil + } +} + +// ErrorHandler is a functional argument that sets error handler of the server +func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { + return func(s *RoundRobin) error { + s.requestRewriteListener = rrl return nil } } @@ -41,19 +50,20 @@ type RoundRobin struct { next http.Handler errHandler utils.ErrorHandler // Current index (starts from -1) - index int - servers []*server - currentWeight int - ss *StickySession + index int + servers []*server + currentWeight int + stickySession *StickySession + requestRewriteListener RequestRewriteListener } func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { rr := &RoundRobin{ - next: next, - index: -1, - mutex: &sync.Mutex{}, - servers: []*server{}, - ss: nil, + next: next, + index: -1, + mutex: &sync.Mutex{}, + servers: []*server{}, + stickySession: nil, } for _, o := range opts { if err := o(rr); err != nil { @@ -71,19 +81,24 @@ func (r *RoundRobin) Next() http.Handler { } func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if log.GetLevel() >= log.DebugLevel { + logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/roundrobin/rr: competed ServeHttp on request") + } + // make shallow copy of request before chaning anything to avoid side effects newReq := *req stuck := false - if r.ss != nil { - cookie_url, present, err := r.ss.GetBackend(&newReq, r.Servers()) + if r.stickySession != nil { + cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers()) if err != nil { - r.errHandler.ServeHTTP(w, req, err) - return + log.Infof("vulcand/oxy/roundrobin/rr: error using server from cookie: %v", err) } if present { - newReq.URL = cookie_url + newReq.URL = cookieURL stuck = true } } @@ -95,11 +110,22 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - if r.ss != nil { - r.ss.StickBackend(url, &w) + if r.stickySession != nil { + r.stickySession.StickBackend(url, &w) } newReq.URL = url } + + if log.GetLevel() >= log.DebugLevel { + //log which backend URL we're sending this request to + log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL") + } + + //Emit event to a listener if one exists + if r.requestRewriteListener != nil { + r.requestRewriteListener(req, &newReq) + } + r.next.ServeHTTP(w, &newReq) } @@ -144,8 +170,6 @@ func (r *RoundRobin) nextServer() (*server, error) { return srv, nil } } - // We did full circle and found no available servers - return nil, fmt.Errorf("no available servers") } func (r *RoundRobin) RemoveServer(u *url.URL) error { diff --git a/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go b/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go index d24dfd940..3fabeb975 100644 --- a/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go +++ b/vendor/github.com/vulcand/oxy/roundrobin/stickysessions.go @@ -7,16 +7,16 @@ import ( ) type StickySession struct { - cookiename string + cookieName string } -func NewStickySession(c string) *StickySession { - return &StickySession{c} +func NewStickySession(cookieName string) *StickySession { + return &StickySession{cookieName} } // GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) { - cookie, err := req.Cookie(s.cookiename) + cookie, err := req.Cookie(s.cookieName) switch err { case nil: case http.ErrNoCookie: @@ -25,22 +25,21 @@ func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url. return nil, false, err } - s_url, err := url.Parse(cookie.Value) + serverURL, err := url.Parse(cookie.Value) if err != nil { return nil, false, err } - if s.isBackendAlive(s_url, servers) { - return s_url, true, nil + if s.isBackendAlive(serverURL, servers) { + return serverURL, true, nil } else { return nil, false, nil } } func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) { - c := &http.Cookie{Name: s.cookiename, Value: backend.String(), Path: "/"} - http.SetCookie(*w, c) - return + cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"} + http.SetCookie(*w, cookie) } func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) bool { @@ -48,8 +47,8 @@ func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) boo return false } - for _, s := range haystack { - if sameURL(needle, s) { + for _, serverURL := range haystack { + if sameURL(needle, serverURL) { return true } } diff --git a/vendor/github.com/vulcand/oxy/stream/stream.go b/vendor/github.com/vulcand/oxy/stream/stream.go index 38bc7f49e..52263753b 100644 --- a/vendor/github.com/vulcand/oxy/stream/stream.go +++ b/vendor/github.com/vulcand/oxy/stream/stream.go @@ -1,9 +1,21 @@ /* -package stream provides http.Handler middleware that solves several problems when dealing with http requests: +package stream provides http.Handler middleware that passes-through the entire request -Reads the entire request and response into buffer, optionally buffering it to disk for large requests. -Checks the limits for the requests and responses, rejecting in case if the limit was exceeded. -Changes request content-transfer-encoding from chunked and provides total size to the handlers. +Stream works around several limitations caused by buffering implementations, but +also introduces certain risks. + +Workarounds for buffering limitations: +1. Streaming really large chunks of data (large file transfers, or streaming videos, +etc.) + +2. Streaming (chunking) sparse data. For example, an implementation might +send a health check or a heart beat over a long-lived connection. This +does not play well with buffering. + +Risks: +1. Connections could survive for very long periods of time. + +2. There is no easy way to enforce limits on size/time of a connection. Examples of a streaming middleware: @@ -12,340 +24,69 @@ Examples of a streaming middleware: w.Write([]byte("hello")) }) - // Stream will read the body in buffer before passing the request to the handler - // calculate total size of the request and transform it from chunked encoding - // before passing to the server + // Stream will literally pass through to the next handler without ANY buffering + // or validation of the data. stream.New(handler) - // This version will buffer up to 2MB in memory and will serialize any extra - // to a temporary file, if the request size exceeds 10MB it will reject the request - stream.New(handler, - stream.MemRequestBodyBytes(2 * 1024 * 1024), - stream.MaxRequestBodyBytes(10 * 1024 * 1024)) - - // Will do the same as above, but with responses - stream.New(handler, - stream.MemResponseBodyBytes(2 * 1024 * 1024), - stream.MaxResponseBodyBytes(10 * 1024 * 1024)) - - // Stream will replay the request if the handler returns error at least 3 times - // before returning the response - stream.New(handler, stream.Retry(`IsNetworkError() && Attempts() <= 2`)) - */ package stream import ( - "fmt" - "io" - "io/ioutil" "net/http" - "github.com/mailgun/multibuf" + log "github.com/Sirupsen/logrus" "github.com/vulcand/oxy/utils" ) const ( - // Store up to 1MB in RAM - DefaultMemBodyBytes = 1048576 // No limit by default DefaultMaxBodyBytes = -1 - // Maximum retry attempts - DefaultMaxRetryAttempts = 10 ) -var errHandler utils.ErrorHandler = &SizeErrHandler{} - -// Streamer is responsible for streaming requests and responses -// It buffers large reqeuests and responses to disk, -type Streamer struct { +// Stream is responsible for buffering requests and responses +// It buffers large requests and responses to disk, +type Stream struct { maxRequestBodyBytes int64 - memRequestBodyBytes int64 maxResponseBodyBytes int64 - memResponseBodyBytes int64 retryPredicate hpredicate next http.Handler errHandler utils.ErrorHandler - log utils.Logger } // New returns a new streamer middleware. New() function supports optional functional arguments -func New(next http.Handler, setters ...optSetter) (*Streamer, error) { - strm := &Streamer{ +func New(next http.Handler, setters ...optSetter) (*Stream, error) { + strm := &Stream{ next: next, maxRequestBodyBytes: DefaultMaxBodyBytes, - memRequestBodyBytes: DefaultMemBodyBytes, maxResponseBodyBytes: DefaultMaxBodyBytes, - memResponseBodyBytes: DefaultMemBodyBytes, } for _, s := range setters { if err := s(strm); err != nil { return nil, err } } - if strm.errHandler == nil { - strm.errHandler = errHandler - } - - if strm.log == nil { - strm.log = utils.NullLogger - } - return strm, nil } -type optSetter func(s *Streamer) error - -// Retry provides a predicate that allows stream middleware to replay the request -// if it matches certain condition, e.g. returns special error code. Available functions are: -// -// Attempts() - limits the amount of retry attempts -// ResponseCode() - returns http response code -// IsNetworkError() - tests if response code is related to networking error -// -// Example of the predicate: -// -// `Attempts() <= 2 && ResponseCode() == 502` -// -func Retry(predicate string) optSetter { - return func(s *Streamer) error { - p, err := parseExpression(predicate) - if err != nil { - return err - } - s.retryPredicate = p - return nil - } -} - -// Logger sets the logger that will be used by this middleware. -func Logger(l utils.Logger) optSetter { - return func(s *Streamer) error { - s.log = l - return nil - } -} - -// ErrorHandler sets error handler of the server -func ErrorHandler(h utils.ErrorHandler) optSetter { - return func(s *Streamer) error { - s.errHandler = h - return nil - } -} - -// MaxRequestBodyBytes sets the maximum request body size in bytes -func MaxRequestBodyBytes(m int64) optSetter { - return func(s *Streamer) error { - if m < 0 { - return fmt.Errorf("max bytes should be >= 0 got %d", m) - } - s.maxRequestBodyBytes = m - return nil - } -} - -// MaxRequestBody bytes sets the maximum request body to be stored in memory -// stream middleware will serialize the excess to disk. -func MemRequestBodyBytes(m int64) optSetter { - return func(s *Streamer) error { - if m < 0 { - return fmt.Errorf("mem bytes should be >= 0 got %d", m) - } - s.memRequestBodyBytes = m - return nil - } -} - -// MaxResponseBodyBytes sets the maximum request body size in bytes -func MaxResponseBodyBytes(m int64) optSetter { - return func(s *Streamer) error { - if m < 0 { - return fmt.Errorf("max bytes should be >= 0 got %d", m) - } - s.maxResponseBodyBytes = m - return nil - } -} - -// MemResponseBodyBytes sets the maximum request body to be stored in memory -// stream middleware will serialize the excess to disk. -func MemResponseBodyBytes(m int64) optSetter { - return func(s *Streamer) error { - if m < 0 { - return fmt.Errorf("mem bytes should be >= 0 got %d", m) - } - s.memResponseBodyBytes = m - return nil - } -} +type optSetter func(s *Stream) error // Wrap sets the next handler to be called by stream handler. -func (s *Streamer) Wrap(next http.Handler) error { +func (s *Stream) Wrap(next http.Handler) error { s.next = next return nil } -func (s *Streamer) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if err := s.checkLimit(req); err != nil { - s.log.Infof("request body over limit: %v", err) - s.errHandler.ServeHTTP(w, req, err) - return +func (s *Stream) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if log.GetLevel() >= log.DebugLevel { + logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) + logEntry.Debug("vulcand/oxy/stream: begin ServeHttp on request") + defer logEntry.Debug("vulcand/oxy/stream: competed ServeHttp on request") } - // Read the body while keeping limits in mind. This reader controls the maximum bytes - // to read into memory and disk. This reader returns an error if the total request size exceeds the - // prefefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 - // and the reader would be unbounded bufio in the http.Server - body, err := multibuf.New(req.Body, multibuf.MaxBytes(s.maxRequestBodyBytes), multibuf.MemBytes(s.memRequestBodyBytes)) - if err != nil || body == nil { - s.errHandler.ServeHTTP(w, req, err) - return - } - - // Set request body to buffered reader that can replay the read and execute Seek - // Note that we don't change the original request body as it's handled by the http server - // and we don'w want to mess with standard library - defer body.Close() - - // We need to set ContentLength based on known request size. The incoming request may have been - // set without content length or using chunked TransferEncoding - totalSize, err := body.Size() - if err != nil { - s.log.Errorf("failed to get size, err %v", err) - s.errHandler.ServeHTTP(w, req, err) - return - } - - outreq := s.copyRequest(req, body, totalSize) - - attempt := 1 - for { - // We create a special writer that will limit the response size, buffer it to disk if necessary - writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(s.maxResponseBodyBytes), multibuf.MemBytes(s.memResponseBodyBytes)) - if err != nil { - s.errHandler.ServeHTTP(w, req, err) - return - } - - // We are mimicking http.ResponseWriter to replace writer with our special writer - b := &bufferWriter{ - header: make(http.Header), - buffer: writer, - } - defer b.Close() - - s.next.ServeHTTP(b, outreq) - - var reader multibuf.MultiReader - if b.expectBody(outreq) { - rdr, err := writer.Reader() - if err != nil { - s.log.Errorf("failed to read response, err %v", err) - s.errHandler.ServeHTTP(w, req, err) - return - } - defer rdr.Close() - reader = rdr - } - - if (s.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) || - !s.retryPredicate(&context{r: req, attempt: attempt, responseCode: b.code, log: s.log}) { - utils.CopyHeaders(w.Header(), b.Header()) - w.WriteHeader(b.code) - if reader != nil { - io.Copy(w, reader) - } - return - } - - attempt += 1 - if _, err := body.Seek(0, 0); err != nil { - s.log.Errorf("Failed to rewind: error: %v", err) - s.errHandler.ServeHTTP(w, req, err) - return - } - outreq = s.copyRequest(req, body, totalSize) - s.log.Infof("retry Request(%v %v) attempt %v", req.Method, req.URL, attempt) - } -} - -func (s *Streamer) copyRequest(req *http.Request, body io.ReadCloser, bodySize int64) *http.Request { - o := *req - o.URL = utils.CopyURL(req.URL) - o.Header = make(http.Header) - utils.CopyHeaders(o.Header, req.Header) - o.ContentLength = bodySize - // remove TransferEncoding that could have been previously set because we have transformed the request from chunked encoding - o.TransferEncoding = []string{} - // http.Transport will close the request body on any error, we are controlling the close process ourselves, so we override the closer here - o.Body = ioutil.NopCloser(body) - return &o -} - -func (s *Streamer) checkLimit(req *http.Request) error { - if s.maxRequestBodyBytes <= 0 { - return nil - } - if req.ContentLength > s.maxRequestBodyBytes { - return &multibuf.MaxSizeReachedError{MaxSize: s.maxRequestBodyBytes} - } - return nil -} - -type bufferWriter struct { - header http.Header - code int - buffer multibuf.WriterOnce -} - -// RFC2616 #4.4 -func (b *bufferWriter) expectBody(r *http.Request) bool { - if r.Method == "HEAD" { - return false - } - if (b.code >= 100 && b.code < 200) || b.code == 204 || b.code == 304 { - return false - } - if b.header.Get("Content-Length") == "" && b.header.Get("Transfer-Encoding") == "" { - return false - } - if b.header.Get("Content-Length") == "0" { - return false - } - return true -} - -func (b *bufferWriter) Close() error { - return b.buffer.Close() -} - -func (b *bufferWriter) Header() http.Header { - return b.header -} - -func (b *bufferWriter) Write(buf []byte) (int, error) { - return b.buffer.Write(buf) -} - -// WriteHeader sets rw.Code. -func (b *bufferWriter) WriteHeader(code int) { - b.code = code -} - -type SizeErrHandler struct { -} - -func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { - if _, ok := err.(*multibuf.MaxSizeReachedError); ok { - w.WriteHeader(http.StatusRequestEntityTooLarge) - w.Write([]byte(http.StatusText(http.StatusRequestEntityTooLarge))) - return - } - utils.DefaultHandler.ServeHTTP(w, req, err) + s.next.ServeHTTP(w, req) } diff --git a/vendor/github.com/vulcand/oxy/stream/threshold.go b/vendor/github.com/vulcand/oxy/stream/threshold.go index ce1408e27..08d725415 100644 --- a/vendor/github.com/vulcand/oxy/stream/threshold.go +++ b/vendor/github.com/vulcand/oxy/stream/threshold.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" - "github.com/vulcand/oxy/utils" "github.com/vulcand/predicate" ) @@ -17,7 +16,6 @@ type context struct { r *http.Request attempt int responseCode int - log utils.Logger } type hpredicate func(*context) bool diff --git a/vendor/github.com/vulcand/oxy/utils/auth.go b/vendor/github.com/vulcand/oxy/utils/auth.go index 9017d8ec9..b80b91685 100644 --- a/vendor/github.com/vulcand/oxy/utils/auth.go +++ b/vendor/github.com/vulcand/oxy/utils/auth.go @@ -22,18 +22,18 @@ func ParseAuthHeader(header string) (*BasicAuth, error) { return nil, fmt.Errorf(fmt.Sprintf("Failed to parse header '%s'", header)) } - auth_type := strings.ToLower(values[0]) - if auth_type != "basic" { - return nil, fmt.Errorf("Expected basic auth type, got '%s'", auth_type) + authType := strings.ToLower(values[0]) + if authType != "basic" { + return nil, fmt.Errorf("Expected basic auth type, got '%s'", authType) } - encoded_string := values[1] - decoded_string, err := base64.StdEncoding.DecodeString(encoded_string) + encodedString := values[1] + decodedString, err := base64.StdEncoding.DecodeString(encodedString) if err != nil { return nil, fmt.Errorf("Failed to parse header '%s', base64 failed: %s", header, err) } - values = strings.SplitN(string(decoded_string), ":", 2) + values = strings.SplitN(string(decodedString), ":", 2) if len(values) != 2 { return nil, fmt.Errorf("Failed to parse header '%s', expected separator ':'", header) } diff --git a/vendor/github.com/vulcand/oxy/utils/dumpreq.go b/vendor/github.com/vulcand/oxy/utils/dumpreq.go new file mode 100644 index 000000000..ef34d38f6 --- /dev/null +++ b/vendor/github.com/vulcand/oxy/utils/dumpreq.go @@ -0,0 +1,60 @@ +package utils + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "mime/multipart" + "net/http" + "net/url" +) + +type SerializableHttpRequest struct { + Method string + URL *url.URL + Proto string // "HTTP/1.0" + ProtoMajor int // 1 + ProtoMinor int // 0 + Header http.Header + ContentLength int64 + TransferEncoding []string + Host string + Form url.Values + PostForm url.Values + MultipartForm *multipart.Form + Trailer http.Header + RemoteAddr string + RequestURI string + TLS *tls.ConnectionState +} + +func Clone(r *http.Request) *SerializableHttpRequest { + if r == nil { + return nil + } + + rc := new(SerializableHttpRequest) + rc.Method = r.Method + rc.URL = r.URL + rc.Proto = r.Proto + rc.ProtoMajor = r.ProtoMajor + rc.ProtoMinor = r.ProtoMinor + rc.Header = r.Header + rc.ContentLength = r.ContentLength + rc.Host = r.Host + rc.RemoteAddr = r.RemoteAddr + rc.RequestURI = r.RequestURI + return rc +} + +func (s *SerializableHttpRequest) ToJson() string { + if jsonVal, err := json.Marshal(s); err != nil || jsonVal == nil { + return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err.Error()) + } else { + return string(jsonVal) + } +} + +func DumpHttpRequest(req *http.Request) string { + return fmt.Sprintf("%v", Clone(req).ToJson()) +} diff --git a/vendor/github.com/vulcand/oxy/utils/logging.go b/vendor/github.com/vulcand/oxy/utils/logging.go deleted file mode 100644 index 7b036b7cd..000000000 --- a/vendor/github.com/vulcand/oxy/utils/logging.go +++ /dev/null @@ -1,86 +0,0 @@ -package utils - -import ( - "io" - "log" -) - -var NullLogger Logger = &NOPLogger{} - -// Logger defines a simple logging interface -type Logger interface { - Infof(format string, args ...interface{}) - Warningf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) -} - -type FileLogger struct { - info *log.Logger - warn *log.Logger - error *log.Logger -} - -func NewFileLogger(w io.Writer, lvl LogLevel) *FileLogger { - l := &FileLogger{} - flag := log.Ldate | log.Ltime | log.Lmicroseconds - if lvl <= INFO { - l.info = log.New(w, "INFO: ", flag) - } - if lvl <= WARN { - l.warn = log.New(w, "WARN: ", flag) - } - if lvl <= ERROR { - l.error = log.New(w, "ERR: ", flag) - } - return l -} - -func (f *FileLogger) Infof(format string, args ...interface{}) { - if f.info == nil { - return - } - f.info.Printf(format, args...) -} - -func (f *FileLogger) Warningf(format string, args ...interface{}) { - if f.warn == nil { - return - } - f.warn.Printf(format, args...) -} - -func (f *FileLogger) Errorf(format string, args ...interface{}) { - if f.error == nil { - return - } - f.error.Printf(format, args...) -} - -type NOPLogger struct { -} - -func (*NOPLogger) Infof(format string, args ...interface{}) { - -} -func (*NOPLogger) Warningf(format string, args ...interface{}) { -} - -func (*NOPLogger) Errorf(format string, args ...interface{}) { -} - -func (*NOPLogger) Info(string) { - -} -func (*NOPLogger) Warning(string) { -} - -func (*NOPLogger) Error(string) { -} - -type LogLevel int - -const ( - INFO = iota - WARN - ERROR -) diff --git a/vendor/github.com/vulcand/oxy/utils/netutils.go b/vendor/github.com/vulcand/oxy/utils/netutils.go index 236ffdd34..9c247e4cc 100644 --- a/vendor/github.com/vulcand/oxy/utils/netutils.go +++ b/vendor/github.com/vulcand/oxy/utils/netutils.go @@ -2,19 +2,23 @@ package utils import ( "bufio" + "fmt" "io" - "mime" "net" "net/http" "net/url" + "reflect" + + log "github.com/Sirupsen/logrus" ) // ProxyWriter helps to capture response headers and status code // from the ServeHTTP. It can be safely passed to ServeHTTP handler, // wrapping the real response writer. type ProxyWriter struct { - W http.ResponseWriter - Code int + W http.ResponseWriter + Code int + Length int64 } func (p *ProxyWriter) StatusCode() int { @@ -31,6 +35,7 @@ func (p *ProxyWriter) Header() http.Header { } func (p *ProxyWriter) Write(buf []byte) (int, error) { + p.Length = p.Length + int64(len(buf)) return p.W.Write(buf) } @@ -45,8 +50,20 @@ func (p *ProxyWriter) Flush() { } } +func (p *ProxyWriter) CloseNotify() <-chan bool { + if cn, ok := p.W.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.W)) + return make(<-chan bool) +} + func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return p.W.(http.Hijacker).Hijack() + if hi, ok := p.W.(http.Hijacker); ok { + return hi.Hijack() + } + log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.W)) + return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.W)) } func NewBufferWriter(w io.WriteCloser) *BufferWriter { @@ -79,8 +96,20 @@ func (b *BufferWriter) WriteHeader(code int) { b.Code = code } +func (b *BufferWriter) CloseNotify() <-chan bool { + if cn, ok := b.W.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.W)) + return make(<-chan bool) +} + func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return b.W.(http.Hijacker).Hijack() + if hi, ok := b.W.(http.Hijacker); ok { + return hi.Hijack() + } + log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.W)) + return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.W)) } type nopWriteCloser struct { @@ -106,11 +135,9 @@ func CopyURL(i *url.URL) *url.URL { // CopyHeaders copies http headers from source to destination, it // does not overide, but adds multiple headers -func CopyHeaders(dst, src http.Header) { +func CopyHeaders(dst http.Header, src http.Header) { for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } + dst[k] = append(dst[k], vv...) } } @@ -130,9 +157,3 @@ func RemoveHeaders(headers http.Header, names ...string) { headers.Del(h) } } - -// Parse the MIME media type value of a header. -func GetHeaderMediaType(headers http.Header, name string) (string, error) { - mediatype, _, err := mime.ParseMediaType(headers.Get(name)) - return mediatype, err -}