Resync oxy with original repository

This commit is contained in:
SALLEYRON Julien 2017-11-22 18:20:03 +01:00 committed by Traefiker
parent da5e4a13bf
commit bee8ebb00b
31 changed files with 650 additions and 808 deletions

4
glide.lock generated
View file

@ -1,4 +1,4 @@
hash: 2a604d8b74e8659df7db72d063432fa0822fdee6c81bc6657efa3c1bf0d9cd8a
hash: fec4fec4363272870c49e10cea64cc51095ecd0987b9c020c9714d950cf38784
updated: 2017-11-17T14:21:55.148450413+01:00
imports:
- name: cloud.google.com/go
@ -516,7 +516,7 @@ imports:
- name: github.com/VividCortex/gohistogram
version: 51564d9861991fb0ad0f531c99ef602d0f9866e6
- name: github.com/vulcand/oxy
version: 7e9763c4dc71b9758379da3581e6495c145caaab
version: 7b6e758ab449705195df638765c4ca472248908a
repo: https://github.com/containous/oxy.git
vcs: git
subpackages:

View file

@ -12,7 +12,7 @@ import:
- package: github.com/cenk/backoff
- package: github.com/containous/flaeg
- package: github.com/vulcand/oxy
version: 7e9763c4dc71b9758379da3581e6495c145caaab
version: 7b6e758ab449705195df638765c4ca472248908a
repo: https://github.com/containous/oxy.git
vcs: git
subpackages:

View file

@ -224,10 +224,12 @@ func (s *GRPCSuite) TestGRPCBuffer(c *check.C) {
var client helloworld.Greeter_StreamExampleClient
client, closer, err := callStreamExampleClientGRPC()
defer closer()
c.Assert(err, check.IsNil)
received := make(chan bool)
go func() {
tr, _ := client.Recv()
tr, err := client.Recv()
c.Assert(err, check.IsNil)
c.Assert(len(tr.Data), check.Equals, 512)
received <- true
}()

View file

@ -395,7 +395,7 @@ func (s *WebsocketSuite) TestURLWithURLEncodedChar(c *check.C) {
var upgrader = gorillawebsocket.Upgrader{} // use default options
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.Assert(r.URL.Path, check.Equals, "/ws/http%3A%2F%2Ftest")
c.Assert(r.URL.EscapedPath(), check.Equals, "/ws/http%3A%2F%2Ftest")
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return

View file

@ -2,29 +2,8 @@ package server
import (
"net/http"
"github.com/containous/traefik/log"
)
// OxyLogger implements oxy Logger interface with logrus.
type OxyLogger struct {
}
// Infof logs specified string as Debug level in logrus.
func (oxylogger *OxyLogger) Infof(format string, args ...interface{}) {
log.Debugf(format, args...)
}
// Warningf logs specified string as Warning level in logrus.
func (oxylogger *OxyLogger) Warningf(format string, args ...interface{}) {
log.Warningf(format, args...)
}
// Errorf logs specified string as Warningf level in logrus.
func (oxylogger *OxyLogger) Errorf(format string, args ...interface{}) {
log.Warningf(format, args...)
}
func notFoundHandler(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}

View file

@ -39,7 +39,6 @@ import (
"github.com/eapache/channels"
thoas_stats "github.com/thoas/stats"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/cbreaker"
"github.com/vulcand/oxy/connlimit"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/ratelimit"
@ -48,10 +47,6 @@ import (
"golang.org/x/net/http2"
)
var (
oxyLogger = &OxyLogger{}
)
// Server is the reverse-proxy/load-balancer engine
type Server struct {
serverEntryPoints serverEntryPoints
@ -959,7 +954,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf
}
fwd, err := forward.New(
forward.Logger(oxyLogger),
forward.Stream(true),
forward.PassHostHeader(frontend.PassHostHeader),
forward.RoundTripper(roundTripper),
forward.ErrorHandler(errorHandler),
@ -1006,10 +1001,10 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf
switch lbMethod {
case types.Drr:
log.Debugf("Creating load-balancer drr")
rebalancer, _ := roundrobin.NewRebalancer(rr, roundrobin.RebalancerLogger(oxyLogger))
rebalancer, _ := roundrobin.NewRebalancer(rr)
if sticky != nil {
log.Debugf("Sticky session with cookie %v", cookieName)
rebalancer, _ = roundrobin.NewRebalancer(rr, roundrobin.RebalancerLogger(oxyLogger), roundrobin.RebalancerStickySession(sticky))
rebalancer, _ = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(sticky))
}
lb = rebalancer
if err := configureLBServers(rebalancer, config, frontend); err != nil {
@ -1080,7 +1075,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf
continue frontend
}
log.Debugf("Creating load-balancer connlimit")
lb, err = connlimit.New(lb, extractFunc, maxConns.Amount, connlimit.Logger(oxyLogger))
lb, err = connlimit.New(lb, extractFunc, maxConns.Amount)
if err != nil {
log.Errorf("Error creating connlimit: %v", err)
log.Errorf("Skipping frontend %s...", frontendName)
@ -1151,7 +1146,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf
if config.Backends[frontend.Backend].CircuitBreaker != nil {
log.Debugf("Creating circuit breaker %s", config.Backends[frontend.Backend].CircuitBreaker.Expression)
circuitBreaker, err := middlewares.NewCircuitBreaker(lb, config.Backends[frontend.Backend].CircuitBreaker.Expression, cbreaker.Logger(oxyLogger))
circuitBreaker, err := middlewares.NewCircuitBreaker(lb, config.Backends[frontend.Backend].CircuitBreaker.Expression)
if err != nil {
log.Errorf("Error creating circuit breaker: %v", err)
log.Errorf("Skipping frontend %s...", frontendName)
@ -1445,7 +1440,7 @@ func (server *Server) buildRateLimiter(handler http.Handler, rlConfig *types.Rat
return nil, err
}
}
return ratelimit.New(handler, extractFunc, rateSet, ratelimit.Logger(oxyLogger))
return ratelimit.New(handler, extractFunc, rateSet)
}
func (server *Server) buildRetryMiddleware(handler http.Handler, globalConfig configuration.GlobalConfiguration, countServers int, backendName string) http.Handler {

View file

@ -1,4 +1,4 @@
// package cbreaker implements circuit breaker similar to https://github.com/Netflix/Hystrix/wiki/How-it-Works
// Package cbreaker implements circuit breaker similar to https://github.com/Netflix/Hystrix/wiki/How-it-Works
//
// Vulcan circuit breaker watches the error condtion to match
// after which it activates the fallback scenario, e.g. returns the response code
@ -31,6 +31,8 @@ import (
"sync"
"time"
log "github.com/Sirupsen/logrus"
"github.com/mailgun/timetools"
"github.com/vulcand/oxy/memmetrics"
"github.com/vulcand/oxy/utils"
@ -60,7 +62,6 @@ type CircuitBreaker struct {
fallback http.Handler
next http.Handler
log utils.Logger
clock timetools.TimeProvider
}
@ -75,7 +76,6 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption)
fallbackDuration: defaultFallbackDuration,
recoveryDuration: defaultRecoveryDuration,
fallback: defaultFallback,
log: utils.NullLogger,
}
for _, s := range options {
@ -100,6 +100,11 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption)
}
func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/circuitbreaker: competed ServeHttp on request")
}
if c.activateFallback(w, req) {
c.fallback.ServeHTTP(w, req)
return
@ -121,7 +126,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque
c.m.Lock()
defer c.m.Unlock()
c.log.Infof("%v is in error state", c)
log.Infof("%v is in error state", c)
switch c.state {
case stateStandby:
@ -186,13 +191,13 @@ func (c *CircuitBreaker) exec(s SideEffect) {
}
go func() {
if err := s.Exec(); err != nil {
c.log.Errorf("%v side effect failure: %v", c, err)
log.Errorf("%v side effect failure: %v", c, err)
}
}()
}
func (c *CircuitBreaker) setState(new cbState, until time.Time) {
c.log.Infof("%v setting state to %v, until %v", c, new, until)
log.Infof("%v setting state to %v, until %v", c, new, until)
c.state = new
c.until = until
switch new {
@ -225,7 +230,7 @@ func (c *CircuitBreaker) checkAndSet() {
c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod)
if c.state == stateTripped {
c.log.Infof("%v skip set tripped", c)
log.Infof("%v skip set tripped", c)
return
}
@ -309,14 +314,6 @@ func Fallback(h http.Handler) CircuitBreakerOption {
}
}
// Logger adds logging for the CircuitBreaker.
func Logger(l utils.Logger) CircuitBreakerOption {
return func(c *CircuitBreaker) error {
c.log = l
return nil
}
}
// cbState is the state of the circuit breaker
type cbState int

View file

@ -9,6 +9,7 @@ import (
"net/url"
"strings"
log "github.com/Sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
@ -68,9 +69,10 @@ func (w *WebhookSideEffect) Exec() error {
if re.Body != nil {
defer re.Body.Close()
}
_, err = ioutil.ReadAll(re.Body)
body, err := ioutil.ReadAll(re.Body)
if err != nil {
return err
}
log.Infof("%v got response: (%s): %s", w, re.Status, string(body))
return nil
}

View file

@ -5,6 +5,9 @@ import (
"net/http"
"net/url"
"strconv"
log "github.com/Sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
type Response struct {
@ -25,20 +28,31 @@ func NewResponseFallback(r Response) (*ResponseFallback, error) {
}
func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/fallback/response: competed ServeHttp on request")
}
if f.r.ContentType != "" {
w.Header().Set("Content-Type", f.r.ContentType)
}
w.Header().Set("Content-Length", strconv.Itoa(len(f.r.Body)))
w.WriteHeader(f.r.StatusCode)
w.Write(f.r.Body)
_, err := w.Write(f.r.Body)
if err != nil {
log.Errorf("vulcand/oxy/fallback/response: failed to write response, err: %v", err)
}
}
type Redirect struct {
URL string
URL string
PreservePath bool
}
type RedirectFallback struct {
u *url.URL
r Redirect
}
func NewRedirectFallback(r Redirect) (*RedirectFallback, error) {
@ -46,11 +60,25 @@ func NewRedirectFallback(r Redirect) (*RedirectFallback, error) {
if err != nil {
return nil, err
}
return &RedirectFallback{u: u}, nil
return &RedirectFallback{u: u, r: r}, nil
}
func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Location", f.u.String())
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/fallback/redirect: competed ServeHttp on request")
}
location := f.u.String()
if f.r.PreservePath {
location += req.URL.Path
}
w.Header().Set("Location", location)
w.WriteHeader(http.StatusFound)
w.Write([]byte(http.StatusText(http.StatusFound)))
_, err := w.Write([]byte(http.StatusText(http.StatusFound)))
if err != nil {
log.Errorf("vulcand/oxy/fallback/redirect: failed to write response, err: %v", err)
}
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
"time"
log "github.com/Sirupsen/logrus"
"github.com/vulcand/predicate"
)
@ -49,7 +50,7 @@ func latencyAtQuantile(quantile float64) toInt {
return func(c *CircuitBreaker) int {
h, err := c.metrics.LatencyHistogram()
if err != nil {
c.log.Errorf("Failed to get latency histogram, for %v error: %v", c, err)
log.Errorf("Failed to get latency histogram, for %v error: %v", c, err)
return 0
}
return int(h.LatencyAtQuantile(quantile) / time.Millisecond)

View file

@ -4,6 +4,7 @@ import (
"fmt"
"time"
log "github.com/Sirupsen/logrus"
"github.com/mailgun/timetools"
)
@ -33,14 +34,17 @@ func (r *ratioController) String() string {
}
func (r *ratioController) allowRequest() bool {
log.Infof("%v", r)
t := r.targetRatio()
// This condition answers the question - would we satisfy the target ratio if we allow this request?
e := r.computeRatio(r.allowed+1, r.denied)
if e < t {
r.allowed++
log.Infof("%v allowed", r)
return true
}
r.denied++
log.Infof("%v denied", r)
return false
}

View file

@ -6,6 +6,7 @@ import (
"net/http"
"sync"
log "github.com/Sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
@ -20,7 +21,6 @@ type ConnLimiter struct {
next http.Handler
errHandler utils.ErrorHandler
log utils.Logger
}
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) {
@ -40,9 +40,6 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64,
return nil, err
}
}
if cl.log == nil {
cl.log = utils.NullLogger
}
if cl.errHandler == nil {
cl.errHandler = defaultErrHandler
}
@ -56,12 +53,12 @@ func (cl *ConnLimiter) Wrap(h http.Handler) {
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
token, amount, err := cl.extract.Extract(r)
if err != nil {
cl.log.Errorf("failed to extract source of the connection: %v", err)
log.Errorf("failed to extract source of the connection: %v", err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
if err := cl.acquire(token, amount); err != nil {
cl.log.Infof("limiting request source %s: %v", token, err)
log.Infof("limiting request source %s: %v", token, err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
@ -81,7 +78,7 @@ func (cl *ConnLimiter) acquire(token string, amount int64) error {
}
cl.connections[token] += amount
cl.totalConnections += int64(amount)
cl.totalConnections += amount
return nil
}
@ -90,7 +87,7 @@ func (cl *ConnLimiter) release(token string, amount int64) {
defer cl.mutex.Unlock()
cl.connections[token] -= amount
cl.totalConnections -= int64(amount)
cl.totalConnections -= amount
// Otherwise it would grow forever
if cl.connections[token] == 0 {
@ -110,6 +107,12 @@ type ConnErrHandler struct {
}
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/connlimit: competed ServeHttp on request")
}
if _, ok := err.(*MaxConnError); ok {
w.WriteHeader(429)
w.Write([]byte(err.Error()))
@ -120,14 +123,6 @@ func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err
type ConnLimitOption func(l *ConnLimiter) error
// Logger sets the logger that will be used by this middleware.
func Logger(l utils.Logger) ConnLimitOption {
return func(cl *ConnLimiter) error {
cl.log = l
return nil
}
}
// ErrorHandler sets error handler of the server
func ErrorHandler(h utils.ErrorHandler) ConnLimitOption {
return func(cl *ConnLimiter) error {

View file

@ -7,13 +7,15 @@ import (
"crypto/tls"
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/gorilla/websocket"
"github.com/vulcand/oxy/utils"
)
@ -29,15 +31,7 @@ type optSetter func(f *Forwarder) error
// be delegated
func PassHostHeader(b bool) optSetter {
return func(f *Forwarder) error {
f.passHost = b
return nil
}
}
// StreamResponse forces streaming body (flushes response directly to client)
func StreamResponse(b bool) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.streamResponse = b
f.httpForwarder.passHost = b
return nil
}
}
@ -46,7 +40,7 @@ func StreamResponse(b bool) optSetter {
// Forwarder will use http.DefaultTransport as a default round tripper
func RoundTripper(r http.RoundTripper) optSetter {
return func(f *Forwarder) error {
f.roundTripper = r
f.httpForwarder.roundTripper = r
return nil
}
}
@ -59,10 +53,11 @@ func Rewriter(r ReqRewriter) optSetter {
}
}
// WebsocketRewriter defines a request rewriter for the websocket forwarder
func WebsocketRewriter(r ReqRewriter) optSetter {
// PassHostHeader specifies if a client's Host header field should
// be delegated
func WebsocketTLSClientConfig(tcc *tls.Config) optSetter {
return func(f *Forwarder) error {
f.websocketForwarder.rewriter = r
f.httpForwarder.tlsClientConfig = tcc
return nil
}
}
@ -75,27 +70,74 @@ func ErrorHandler(h utils.ErrorHandler) optSetter {
}
}
// Logger specifies the logger to use.
// Forwarder will default to oxyutils.NullLogger if no logger has been specified
func Logger(l utils.Logger) optSetter {
// Stream specifies if HTTP responses should be streamed.
func Stream(stream bool) optSetter {
return func(f *Forwarder) error {
f.stream = stream
return nil
}
}
// Logger defines the logger the forwarder will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) optSetter {
return func(f *Forwarder) error {
f.log = l
return nil
}
}
func StateListener(stateListener UrlForwardingStateListener) optSetter {
return func(f *Forwarder) error {
f.stateListener = stateListener
return nil
}
}
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.modifyResponse = responseModifier
return nil
}
}
func StreamingFlushInterval(flushInterval time.Duration) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.flushInterval = flushInterval
return nil
}
}
type ErrorHandlingRoundTripper struct {
http.RoundTripper
errorHandler utils.ErrorHandler
}
func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTripper.RoundTrip(req)
if err != nil {
// We use the recorder from httptest because there isn't another `public` implementation of a recorder.
recorder := httptest.NewRecorder()
rt.errorHandler.ServeHTTP(recorder, req, err)
res = recorder.Result()
err = nil
}
return res, err
}
// Forwarder wraps two traffic forwarding implementations: HTTP and websockets.
// It decides based on the specified request which implementation to use
type Forwarder struct {
*httpForwarder
*websocketForwarder
*handlerContext
stateListener UrlForwardingStateListener
stream bool
}
// handlerContext defines a handler context for error reporting and logging
type handlerContext struct {
errHandler utils.ErrorHandler
log utils.Logger
}
// httpForwarder is a handler that can reverse proxy
@ -104,32 +146,40 @@ type httpForwarder struct {
roundTripper http.RoundTripper
rewriter ReqRewriter
passHost bool
streamResponse bool
flushInterval time.Duration
modifyResponse func(*http.Response) error
tlsClientConfig *tls.Config
log *log.Logger
}
// websocketForwarder is a handler that can reverse proxy
// websocket traffic
type websocketForwarder struct {
rewriter ReqRewriter
TLSClientConfig *tls.Config
}
const (
defaultFlushInterval = time.Duration(100) * time.Millisecond
StateConnected = iota
StateDisconnected
)
type UrlForwardingStateListener func(*url.URL, int)
// New creates an instance of Forwarder based on the provided list of configuration options
func New(setters ...optSetter) (*Forwarder, error) {
f := &Forwarder{
httpForwarder: &httpForwarder{},
websocketForwarder: &websocketForwarder{},
handlerContext: &handlerContext{},
httpForwarder: &httpForwarder{log: log.StandardLogger()},
handlerContext: &handlerContext{},
}
for _, s := range setters {
if err := s(f); err != nil {
return nil, err
}
}
if f.httpForwarder.roundTripper == nil {
f.httpForwarder.roundTripper = http.DefaultTransport
if !f.stream {
f.flushInterval = 0
} else if f.flushInterval == 0 {
f.flushInterval = defaultFlushInterval
}
f.websocketForwarder.TLSClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig
if f.httpForwarder.rewriter == nil {
h, err := os.Hostname()
if err != nil {
@ -137,136 +187,104 @@ func New(setters ...optSetter) (*Forwarder, error) {
}
f.httpForwarder.rewriter = &HeaderRewriter{TrustForwardHeader: true, Hostname: h}
}
if f.log == nil {
f.log = utils.NullLogger
if f.httpForwarder.roundTripper == nil {
f.httpForwarder.roundTripper = http.DefaultTransport
}
if f.errHandler == nil {
f.errHandler = utils.DefaultHandler
}
if f.tlsClientConfig == nil {
f.tlsClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig
}
f.httpForwarder.roundTripper = ErrorHandlingRoundTripper{
RoundTripper: f.httpForwarder.roundTripper,
errorHandler: f.errHandler,
}
return f, nil
}
// ServeHTTP decides which forwarder to use based on the specified
// request and delegates to the proper implementation
func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if isWebsocketRequest(req) {
f.websocketForwarder.serveHTTP(w, req, f.handlerContext)
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/forward: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/forward: competed ServeHttp on request")
}
if f.stateListener != nil {
f.stateListener(req.URL, StateConnected)
defer f.stateListener(req.URL, StateDisconnected)
}
if IsWebsocketRequest(req) {
f.httpForwarder.serveWebSocket(w, req, f.handlerContext)
} else {
f.httpForwarder.serveHTTP(w, req, f.handlerContext)
}
}
// serveHTTP forwards HTTP traffic using the configured transport
func (f *httpForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, ctx *handlerContext) {
start := time.Now().UTC()
response, err := f.roundTripper.RoundTrip(f.copyRequest(req, req.URL))
if err != nil {
ctx.log.Errorf("Error forwarding to %v, err: %v", req.URL, err)
ctx.errHandler.ServeHTTP(w, req, err)
return
}
utils.CopyHeaders(w.Header(), response.Header)
// Remove hop-by-hop headers.
utils.RemoveHeaders(w.Header(), HopHeaders...)
announcedTrailerKeyCount := len(response.Trailer)
if announcedTrailerKeyCount > 0 {
trailerKeys := make([]string, 0, announcedTrailerKeyCount)
for k := range response.Trailer {
trailerKeys = append(trailerKeys, k)
}
w.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
w.WriteHeader(response.StatusCode)
stream := f.streamResponse
if !stream {
contentType, err := utils.GetHeaderMediaType(response.Header, ContentType)
func (f *httpForwarder) getUrlFromRequest(req *http.Request) *url.URL {
// If the Request was created by Go via a real HTTP request, RequestURI will
// contain the original query string. If the Request was created in code, RequestURI
// will be empty, and we will use the URL object instead
u := req.URL
if req.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(req.RequestURI)
if err == nil {
stream = contentType == "text/event-stream"
u = parsedURL
} else {
f.log.Warnf("vulcand/oxy/forward: error when parsing RequestURI: %s", err)
}
}
flush := stream || req.ProtoMajor == 2
written, err := io.Copy(newResponseFlusher(w, flush), response.Body)
if err != nil {
ctx.log.Errorf("Error copying upstream response body: %v", err)
ctx.errHandler.ServeHTTP(w, req, err)
return
}
defer response.Body.Close()
forceSetTrailers := len(response.Trailer) != announcedTrailerKeyCount
shallowCopyTrailers(w.Header(), response.Trailer, forceSetTrailers)
if written != 0 {
w.Header().Set(ContentLength, strconv.FormatInt(written, 10))
}
if req.TLS != nil {
ctx.log.Infof("Round trip: %v, code: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
req.URL, response.StatusCode, time.Now().UTC().Sub(start),
req.TLS.Version,
req.TLS.DidResume,
req.TLS.CipherSuite,
req.TLS.ServerName)
} else {
ctx.log.Infof("Round trip: %v, code: %v, duration: %v",
req.URL, response.StatusCode, time.Now().UTC().Sub(start))
}
return u
}
// copyRequest makes a copy of the specified request to be sent using the configured
// transport
func (f *httpForwarder) copyRequest(req *http.Request, u *url.URL) *http.Request {
outReq := new(http.Request)
*outReq = *req // includes shallow copies of maps, but we handle this below
// Modify the request to handle the target URL
func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) {
outReq.URL = utils.CopyURL(outReq.URL)
outReq.URL.Scheme = target.Scheme
outReq.URL.Host = target.Host
u := f.getUrlFromRequest(outReq)
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.URL = utils.CopyURL(req.URL)
outReq.URL.Scheme = u.Scheme
outReq.URL.Host = u.Host
outReq.URL.Opaque = req.RequestURI
// raw query is already included in RequestURI, so ignore it to avoid dupes
outReq.URL.RawQuery = ""
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !f.passHost {
outReq.Host = u.Host
outReq.Host = target.Host
}
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
// Overwrite close flag so we can keep persistent connection for the backend servers
outReq.Close = false
outReq.Header = make(http.Header)
utils.CopyHeaders(outReq.Header, req.Header)
if f.rewriter != nil {
f.rewriter.Rewrite(outReq)
}
if req.ContentLength == 0 {
// https://github.com/golang/go/issues/16036: nil Body for http.Transport retries
outReq.Body = nil
}
return outReq
}
// serveHTTP forwards websocket traffic
func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, ctx *handlerContext) {
outReq := f.copyRequest(req, req.URL)
func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, ctx *handlerContext) {
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/forward/websocket: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/forward/websocket: competed ServeHttp on request")
}
outReq := f.copyWebSocketRequest(req)
dialer := websocket.DefaultDialer
if outReq.URL.Scheme == "wss" && f.TLSClientConfig != nil {
dialer.TLSClientConfig = f.TLSClientConfig.Clone()
if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil {
dialer.TLSClientConfig = f.tlsClientConfig.Clone()
// WebSocket is only in http/1.1
dialer.TLSClientConfig.NextProtos = []string{"http/1.1"}
}
targetConn, resp, err := dialer.Dial(outReq.URL.String(), outReq.Header)
@ -274,33 +292,33 @@ func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request,
if resp == nil {
ctx.errHandler.ServeHTTP(w, req, err)
} else {
ctx.log.Errorf("Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status)
log.Errorf("vulcand/oxy/forward/websocket: Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status)
hijacker, ok := w.(http.Hijacker)
if !ok {
ctx.log.Errorf("%s can not be hijack", reflect.TypeOf(w))
log.Errorf("vulcand/oxy/forward/websocket: %s can not be hijack", reflect.TypeOf(w))
ctx.errHandler.ServeHTTP(w, req, err)
return
}
conn, _, err := hijacker.Hijack()
if err != nil {
ctx.log.Errorf("Failed to hijack responseWriter")
ctx.errHandler.ServeHTTP(w, req, err)
conn, _, errHijack := hijacker.Hijack()
if errHijack != nil {
log.Errorf("vulcand/oxy/forward/websocket: Failed to hijack responseWriter")
ctx.errHandler.ServeHTTP(w, req, errHijack)
return
}
defer conn.Close()
err = resp.Write(conn)
if err != nil {
ctx.log.Errorf("Failed to forward response")
ctx.errHandler.ServeHTTP(w, req, err)
errWrite := resp.Write(conn)
if errWrite != nil {
log.Errorf("vulcand/oxy/forward/websocket: Failed to forward response")
ctx.errHandler.ServeHTTP(w, req, errWrite)
return
}
}
return
}
//Only the targetConn choose to CheckOrigin or not
// Only the targetConn choose to CheckOrigin or not
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}
@ -308,62 +326,64 @@ func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request,
utils.RemoveHeaders(resp.Header, WebsocketUpgradeHeaders...)
underlyingConn, err := upgrader.Upgrade(w, req, resp.Header)
if err != nil {
ctx.log.Errorf("Error while upgrading connection : %v", err)
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
return
}
defer underlyingConn.Close()
defer targetConn.Close()
errc := make(chan error, 2)
replicate := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
errc <- err
replicate := func(dst io.Writer, src io.Reader, dstName string, srcName string) {
_, errCopy := io.Copy(dst, src)
if errCopy != nil {
f.log.Errorf("vulcand/oxy/forward/websocket: Error when copying from %s to %s using io.Copy: %v", srcName, dstName, errCopy)
} else {
f.log.Infof("vulcand/oxy/forward/websocket: Copying from %s to %s using io.Copy completed without error.", srcName, dstName)
}
errc <- errCopy
}
go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn())
go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn(), "backend", "client")
// Try to read the first message
t, msg, err := targetConn.ReadMessage()
msgType, msg, err := targetConn.ReadMessage()
if err != nil {
ctx.log.Errorf("Couldn't read first message : %v", err)
log.Errorf("vulcand/oxy/forward/websocket: Couldn't read first message : %v", err)
} else {
underlyingConn.WriteMessage(t, msg)
underlyingConn.WriteMessage(msgType, msg)
}
go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn())
go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn(), "client", "backend")
<-errc
}
// copyRequest makes a copy of the specified request.
func (f *websocketForwarder) copyRequest(req *http.Request, u *url.URL) (outReq *http.Request) {
// copyWebsocketRequest makes a copy of the specified request.
func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Request) {
outReq = new(http.Request)
*outReq = *req // includes shallow copies of maps, but we handle this below
outReq.URL = utils.CopyURL(req.URL)
outReq.URL.Scheme = u.Scheme
outReq.URL.Scheme = req.URL.Scheme
//sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs.
switch u.Scheme {
// sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs.
switch req.URL.Scheme {
case "https":
outReq.URL.Scheme = "wss"
case "http":
outReq.URL.Scheme = "ws"
}
if requestURI, err := url.ParseRequestURI(outReq.RequestURI); err == nil {
if requestURI.RawPath != "" {
outReq.URL.Path = requestURI.RawPath
} else {
outReq.URL.Path = requestURI.Path
}
outReq.URL.RawQuery = requestURI.RawQuery
}
u := f.getUrlFromRequest(outReq)
outReq.URL.Host = u.Host
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.URL.Host = req.URL.Host
outReq.Header = make(http.Header)
//gorilla websocket use this header to set the request.Host tested in checkSameOrigin
// gorilla websocket use this header to set the request.Host tested in checkSameOrigin
outReq.Header.Set("Host", outReq.Host)
utils.CopyHeaders(outReq.Header, req.Header)
utils.RemoveHeaders(outReq.Header, WebsocketDialHeaders...)
@ -374,9 +394,48 @@ func (f *websocketForwarder) copyRequest(req *http.Request, u *url.URL) (outReq
return outReq
}
// serveHTTP forwards HTTP traffic using the configured transport
func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ctx *handlerContext) {
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(inReq))
logEntry.Debug("vulcand/oxy/forward/http: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/forward/http: completed ServeHttp on request")
}
pw := &utils.ProxyWriter{
W: w,
}
start := time.Now().UTC()
outReq := new(http.Request)
*outReq = *inReq // includes shallow copies of maps, but we handle this in Director
revproxy := httputil.ReverseProxy{
Director: func(req *http.Request) {
f.modifyRequest(req, inReq.URL)
},
Transport: f.roundTripper,
FlushInterval: f.flushInterval,
ModifyResponse: f.modifyResponse,
}
revproxy.ServeHTTP(pw, outReq)
if inReq.TLS != nil {
f.log.Infof("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
inReq.URL, pw.Code, pw.Length, time.Now().UTC().Sub(start),
inReq.TLS.Version,
inReq.TLS.DidResume,
inReq.TLS.CipherSuite,
inReq.TLS.ServerName)
} else {
f.log.Infof("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v",
inReq.URL, pw.Code, pw.Length, time.Now().UTC().Sub(start))
}
}
// isWebsocketRequest determines if the specified HTTP request is a
// websocket handshake request
func isWebsocketRequest(req *http.Request) bool {
func IsWebsocketRequest(req *http.Request) bool {
containsHeader := func(name, value string) bool {
items := strings.Split(req.Header.Get(name), ",")
for _, item := range items {
@ -388,12 +447,3 @@ func isWebsocketRequest(req *http.Request) bool {
}
return containsHeader(Connection, "upgrade") && containsHeader(Upgrade, "websocket")
}
func shallowCopyTrailers(dstHeader, srcTrailer http.Header, forceSetTrailers bool) {
for k, vv := range srcTrailer {
if forceSetTrailers {
k = http.TrailerPrefix + k
}
dstHeader[k] = vv
}
}

View file

@ -16,7 +16,6 @@ const (
TransferEncoding = "Transfer-Encoding"
Upgrade = "Upgrade"
ContentLength = "Content-Length"
ContentType = "Content-Type"
SecWebsocketKey = "Sec-Websocket-Key"
SecWebsocketVersion = "Sec-Websocket-Version"
SecWebsocketExtensions = "Sec-Websocket-Extensions"

View file

@ -1,53 +0,0 @@
package forward
import (
"bufio"
"fmt"
"net"
"net/http"
)
var (
_ http.Hijacker = &responseFlusher{}
_ http.Flusher = &responseFlusher{}
_ http.CloseNotifier = &responseFlusher{}
)
type responseFlusher struct {
http.ResponseWriter
flush bool
}
func newResponseFlusher(rw http.ResponseWriter, flush bool) *responseFlusher {
return &responseFlusher{
ResponseWriter: rw,
flush: flush,
}
}
func (wf *responseFlusher) Write(p []byte) (int, error) {
written, err := wf.ResponseWriter.Write(p)
if wf.flush {
wf.Flush()
}
return written, err
}
func (wf *responseFlusher) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := wf.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface")
}
return hijacker.Hijack()
}
func (wf *responseFlusher) CloseNotify() <-chan bool {
return wf.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
func (wf *responseFlusher) Flush() {
flusher, ok := wf.ResponseWriter.(http.Flusher)
if ok {
flusher.Flush()
}
}

View file

@ -14,16 +14,25 @@ type HeaderRewriter struct {
Hostname string
}
// clean up IP in case if it is ipv6 address and it has {zone} information in it, like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692"
func ipv6fix(clientIP string) string {
return strings.Split(clientIP, "%")[0]
}
func (rw *HeaderRewriter) Rewrite(req *http.Request) {
if !rw.TrustForwardHeader {
utils.RemoveHeaders(req.Header, XHeaders...)
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if prior, ok := req.Header[XForwardedFor]; ok {
req.Header.Set(XForwardedFor, strings.Join(prior, ", ")+", "+clientIP)
} else {
req.Header.Set(XForwardedFor, clientIP)
clientIP = ipv6fix(clientIP)
// If not websocket, done in http.ReverseProxy
if IsWebsocketRequest(req) {
if prior, ok := req.Header[XForwardedFor]; ok {
req.Header.Set(XForwardedFor, strings.Join(prior, ", ")+", "+clientIP)
} else {
req.Header.Set(XForwardedFor, clientIP)
}
}
if req.Header.Get(XRealIp) == "" {
@ -40,7 +49,15 @@ func (rw *HeaderRewriter) Rewrite(req *http.Request) {
}
}
if xfp := req.Header.Get(XForwardedPort); xfp == "" {
if IsWebsocketRequest(req) {
if req.Header.Get(XForwardedProto) == "https" {
req.Header.Set(XForwardedProto, "wss")
} else {
req.Header.Set(XForwardedProto, "ws")
}
}
if xfPort := req.Header.Get(XForwardedPort); xfPort == "" {
req.Header.Set(XForwardedPort, forwardedPort(req))
}
@ -52,9 +69,11 @@ func (rw *HeaderRewriter) Rewrite(req *http.Request) {
req.Header.Set(XForwardedServer, rw.Hostname)
}
// Remove hop-by-hop headers to the backend. Especially important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us.
utils.RemoveHeaders(req.Header, HopHeaders...)
if !IsWebsocketRequest(req) {
// Remove hop-by-hop headers to the backend. Especially important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us.
utils.RemoveHeaders(req.Header, HopHeaders...)
}
}
func forwardedPort(req *http.Request) string {

View file

@ -40,7 +40,7 @@ func SplitRatios(values []float64) (good map[float64]bool, bad map[float64]bool)
}
// SplitFloat64 provides simple anomaly detection for skewed data sets with no particular distribution.
// In essense it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly
// In essence it applies the formula if(v > median(values) + threshold * medianAbsoluteDeviation) -> anomaly
// There's a corner case where there are just 2 values, so by definition there's no value that exceeds the threshold.
// This case is solved by introducing additional value that we know is good, e.g. 0. That helps to improve the detection results
// on such data sets.

View file

@ -71,9 +71,7 @@ func (c *RollingCounter) Clone() *RollingCounter {
lastBucket: c.lastBucket,
lastUpdated: c.lastUpdated,
}
for i, v := range c.values {
other.values[i] = v
}
copy(other.values, c.values)
return other
}

View file

@ -34,6 +34,15 @@ func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error)
}, nil
}
func (r *HDRHistogram) Export() *HDRHistogram {
var hist *hdrhistogram.Histogram = nil
if r.h != nil {
snapshot := r.h.Export()
hist = hdrhistogram.Import(snapshot)
}
return &HDRHistogram{low: r.low, high: r.high, sigfigs: r.sigfigs, h: hist}
}
// Returns latency at quantile with microsecond precision
func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration {
return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond
@ -118,6 +127,26 @@ func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration,
return rh, nil
}
func (r *RollingHDRHistogram) Export() *RollingHDRHistogram {
export := &RollingHDRHistogram{}
export.idx = r.idx
export.lastRoll = r.lastRoll
export.period = r.period
export.bucketCount = r.bucketCount
export.low = r.low
export.high = r.high
export.sigfigs = r.sigfigs
export.clock = r.clock
exportBuckets := make([]*HDRHistogram, len(r.buckets))
for i, hist := range r.buckets {
exportBuckets[i] = hist.Export()
}
export.buckets = exportBuckets
return export
}
func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error {
if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs {
return fmt.Errorf("can't merge")
@ -150,8 +179,8 @@ func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) {
return m, err
}
for _, h := range r.buckets {
if m.Merge(h); err != nil {
return nil, err
if errMerge := m.Merge(h); errMerge != nil {
return nil, errMerge
}
}
return m, nil

View file

@ -20,6 +20,7 @@ type RTMetrics struct {
statusCodes map[int]*RollingCounter
statusCodesLock sync.RWMutex
histogram *RollingHDRHistogram
histogramLock sync.RWMutex
newCounter NewCounterFn
newHist NewRollingHistogramFn
@ -102,12 +103,39 @@ func NewRTMetrics(settings ...rrOptSetter) (*RTMetrics, error) {
return m, nil
}
// Returns a new RTMetrics which is a copy of the current one
func (m *RTMetrics) Export() *RTMetrics {
m.statusCodesLock.RLock()
defer m.statusCodesLock.RUnlock()
m.histogramLock.RLock()
defer m.histogramLock.RUnlock()
export := &RTMetrics{}
export.statusCodesLock = sync.RWMutex{}
export.histogramLock = sync.RWMutex{}
export.total = m.total.Clone()
export.netErrors = m.netErrors.Clone()
exportStatusCodes := map[int]*RollingCounter{}
for code, rollingCounter := range m.statusCodes {
exportStatusCodes[code] = rollingCounter.Clone()
}
export.statusCodes = exportStatusCodes
if m.histogram != nil {
export.histogram = m.histogram.Export()
}
export.newCounter = m.newCounter
export.newHist = m.newHist
export.clock = m.clock
return export
}
func (m *RTMetrics) CounterWindowSize() time.Duration {
return m.total.WindowSize()
}
// GetNetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection
// that occured in the given time window compared to the total requests count.
// that occurred in the given time window compared to the total requests count.
func (m *RTMetrics) NetworkErrorRatio() float64 {
if m.total.Count() == 0 {
return 0
@ -148,11 +176,13 @@ func (m *RTMetrics) Append(other *RTMetrics) error {
return err
}
copied := other.Export()
m.statusCodesLock.Lock()
defer m.statusCodesLock.Unlock()
other.statusCodesLock.RLock()
defer other.statusCodesLock.RUnlock()
for code, c := range other.statusCodes {
m.histogramLock.Lock()
defer m.histogramLock.Unlock()
for code, c := range copied.statusCodes {
o, ok := m.statusCodes[code]
if ok {
if err := o.Append(c); err != nil {
@ -163,7 +193,7 @@ func (m *RTMetrics) Append(other *RTMetrics) error {
}
}
return m.histogram.Append(other.histogram)
return m.histogram.Append(copied.histogram)
}
func (m *RTMetrics) Record(code int, duration time.Duration) {
@ -200,35 +230,36 @@ func (m *RTMetrics) StatusCodesCounts() map[int]int64 {
// GetLatencyHistogram computes and returns resulting histogram with latencies observed.
func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) {
m.histogramLock.Lock()
defer m.histogramLock.Unlock()
return m.histogram.Merged()
}
func (m *RTMetrics) Reset() {
m.statusCodesLock.Lock()
defer m.statusCodesLock.Unlock()
m.histogramLock.Lock()
defer m.histogramLock.Unlock()
m.histogram.Reset()
m.total.Reset()
m.netErrors.Reset()
m.statusCodesLock.Lock()
defer m.statusCodesLock.Unlock()
m.statusCodes = make(map[int]*RollingCounter)
}
func (m *RTMetrics) recordNetError() error {
m.netErrors.Inc(1)
return nil
}
func (m *RTMetrics) recordLatency(d time.Duration) error {
m.histogramLock.Lock()
defer m.histogramLock.Unlock()
return m.histogram.RecordLatencies(d, 1)
}
func (m *RTMetrics) recordStatusCode(statusCode int) error {
m.statusCodesLock.RLock()
m.statusCodesLock.Lock()
if c, ok := m.statusCodes[statusCode]; ok {
c.Inc(1)
m.statusCodesLock.RUnlock()
m.statusCodesLock.Unlock()
return nil
}
m.statusCodesLock.RUnlock()
m.statusCodesLock.Unlock()
m.statusCodesLock.Lock()
defer m.statusCodesLock.Unlock()

View file

@ -7,6 +7,7 @@ import (
"sync"
"time"
log "github.com/Sirupsen/logrus"
"github.com/mailgun/timetools"
"github.com/mailgun/ttlmap"
"github.com/vulcand/oxy/utils"
@ -65,7 +66,6 @@ type TokenLimiter struct {
mutex sync.Mutex
bucketSets *ttlmap.TtlMap
errHandler utils.ErrorHandler
log utils.Logger
capacity int
next http.Handler
}
@ -110,7 +110,7 @@ func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
if err := tl.consumeRates(req, source, amount); err != nil {
tl.log.Infof("limiting request %v %v, limit: %v", req.Method, req.URL, err)
log.Infof("limiting request %v %v, limit: %v", req.Method, req.URL, err)
tl.errHandler.ServeHTTP(w, req, err)
return
}
@ -155,7 +155,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
rates, err := tl.extractRates.Extract(req)
if err != nil {
tl.log.Errorf("Failed to retrieve rates: %v", err)
log.Errorf("Failed to retrieve rates: %v", err)
return tl.defaultRates
}
@ -190,14 +190,6 @@ func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err
type TokenLimiterOption func(l *TokenLimiter) error
// Logger sets the logger that will be used by this middleware.
func Logger(l utils.Logger) TokenLimiterOption {
return func(cl *TokenLimiter) error {
cl.log = l
return nil
}
}
// ErrorHandler sets error handler of the server
func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption {
return func(cl *TokenLimiter) error {
@ -233,9 +225,6 @@ func Capacity(cap int) TokenLimiterOption {
var defaultErrHandler = &RateErrHandler{}
func setDefaults(tl *TokenLimiter) {
if tl.log == nil {
tl.log = utils.NullLogger
}
if tl.capacity <= 0 {
tl.capacity = DefaultCapacity
}

View file

@ -0,0 +1,5 @@
package roundrobin
import "net/http"
type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request)

View file

@ -7,6 +7,7 @@ import (
"sync"
"time"
log "github.com/Sirupsen/logrus"
"github.com/mailgun/timetools"
"github.com/vulcand/oxy/memmetrics"
"github.com/vulcand/oxy/utils"
@ -42,22 +43,15 @@ type Rebalancer struct {
// errHandler is HTTP handler called in case of errors
errHandler utils.ErrorHandler
log utils.Logger
ratings []float64
// creates new meters
newMeter NewMeterFn
// sticky session object
ss *StickySession
}
stickySession *StickySession
func RebalancerLogger(log utils.Logger) RebalancerOption {
return func(r *Rebalancer) error {
r.log = log
return nil
}
requestRewriteListener RequestRewriteListener
}
func RebalancerClock(clock timetools.TimeProvider) RebalancerOption {
@ -89,18 +83,26 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption {
}
}
func RebalancerStickySession(ss *StickySession) RebalancerOption {
func RebalancerStickySession(stickySession *StickySession) RebalancerOption {
return func(r *Rebalancer) error {
r.ss = ss
r.stickySession = stickySession
return nil
}
}
// RebalancerErrorHandler is a functional argument that sets error handler of the server
func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption {
return func(r *Rebalancer) error {
r.requestRewriteListener = rrl
return nil
}
}
func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) {
rb := &Rebalancer{
mtx: &sync.Mutex{},
next: handler,
ss: nil,
mtx: &sync.Mutex{},
next: handler,
stickySession: nil,
}
for _, o := range opts {
if err := o(rb); err != nil {
@ -113,9 +115,6 @@ func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalanc
if rb.backoffDuration == 0 {
rb.backoffDuration = 10 * time.Second
}
if rb.log == nil {
rb.log = &utils.NOPLogger{}
}
if rb.newMeter == nil {
rb.newMeter = func() (Meter, error) {
rc, err := memmetrics.NewRatioCounter(10, time.Second, memmetrics.RatioClock(rb.clock))
@ -143,6 +142,12 @@ func (rb *Rebalancer) Servers() []*url.URL {
}
func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: competed ServeHttp on request")
}
pw := &utils.ProxyWriter{W: w}
start := rb.clock.UtcNow()
@ -150,16 +155,15 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
newReq := *req
stuck := false
if rb.ss != nil {
cookie_url, present, err := rb.ss.GetBackend(&newReq, rb.Servers())
if rb.stickySession != nil {
cookieUrl, present, err := rb.stickySession.GetBackend(&newReq, rb.Servers())
if err != nil {
rb.errHandler.ServeHTTP(w, req, err)
return
log.Infof("vulcand/oxy/roundrobin/rebalancer: error using server from cookie: %v", err)
}
if present {
newReq.URL = cookie_url
newReq.URL = cookieUrl
stuck = true
}
}
@ -171,12 +175,23 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
if rb.ss != nil {
rb.ss.StickBackend(url, &w)
if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
}
if rb.stickySession != nil {
rb.stickySession.StickBackend(url, &w)
}
newReq.URL = url
}
//Emit event to a listener if one exists
if rb.requestRewriteListener != nil {
rb.requestRewriteListener(req, &newReq)
}
rb.next.Next().ServeHTTP(pw, &newReq)
rb.recordMetrics(newReq.URL, pw.Code, rb.clock.UtcNow().Sub(start))
@ -262,11 +277,11 @@ func (rb *Rebalancer) upsertServer(u *url.URL, weight int) error {
return nil
}
func (r *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
if len(r.servers) == 0 {
func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
if len(rb.servers) == 0 {
return nil, -1
}
for i, s := range r.servers {
for i, s := range rb.servers {
if sameURL(u, s.url) {
return s, i
}
@ -304,7 +319,7 @@ func (rb *Rebalancer) adjustWeights() {
func (rb *Rebalancer) applyWeights() {
for _, srv := range rb.servers {
rb.log.Infof("upsert server %v, weight %v", srv.url, srv.curWeight)
log.Infof("upsert server %v, weight %v", srv.url, srv.curWeight)
rb.next.UpsertServer(srv.url, Weight(srv.curWeight))
}
}
@ -316,7 +331,7 @@ func (rb *Rebalancer) setMarkedWeights() bool {
if srv.good {
weight := increase(srv.curWeight)
if weight <= FSMMaxWeight {
rb.log.Infof("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight)
log.Infof("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight)
srv.curWeight = weight
changed = true
}
@ -363,13 +378,13 @@ func (rb *Rebalancer) markServers() bool {
}
}
if len(g) != 0 && len(b) != 0 {
rb.log.Infof("bad: %v good: %v, ratings: %v", b, g, rb.ratings)
log.Infof("bad: %v good: %v, ratings: %v", b, g, rb.ratings)
}
return len(g) != 0 && len(b) != 0
}
func (rb *Rebalancer) convergeWeights() bool {
// If we have previoulsy changed servers try to restore weights to the original state
// If we have previously changed servers try to restore weights to the original state
changed := false
for _, s := range rb.servers {
if s.origWeight == s.curWeight {
@ -377,7 +392,7 @@ func (rb *Rebalancer) convergeWeights() bool {
}
changed = true
newWeight := decrease(s.origWeight, s.curWeight)
rb.log.Infof("decreasing weight of %v from %v to %v", s.url, s.curWeight, newWeight)
log.Infof("decreasing weight of %v from %v to %v", s.url, s.curWeight, newWeight)
s.curWeight = newWeight
}
if !changed {

View file

@ -7,6 +7,7 @@ import (
"net/url"
"sync"
log "github.com/Sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
@ -29,9 +30,17 @@ func ErrorHandler(h utils.ErrorHandler) LBOption {
}
}
func EnableStickySession(ss *StickySession) LBOption {
func EnableStickySession(stickySession *StickySession) LBOption {
return func(s *RoundRobin) error {
s.ss = ss
s.stickySession = stickySession
return nil
}
}
// ErrorHandler is a functional argument that sets error handler of the server
func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption {
return func(s *RoundRobin) error {
s.requestRewriteListener = rrl
return nil
}
}
@ -41,19 +50,20 @@ type RoundRobin struct {
next http.Handler
errHandler utils.ErrorHandler
// Current index (starts from -1)
index int
servers []*server
currentWeight int
ss *StickySession
index int
servers []*server
currentWeight int
stickySession *StickySession
requestRewriteListener RequestRewriteListener
}
func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
rr := &RoundRobin{
next: next,
index: -1,
mutex: &sync.Mutex{},
servers: []*server{},
ss: nil,
next: next,
index: -1,
mutex: &sync.Mutex{},
servers: []*server{},
stickySession: nil,
}
for _, o := range opts {
if err := o(rr); err != nil {
@ -71,19 +81,24 @@ func (r *RoundRobin) Next() http.Handler {
}
func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/roundrobin/rr: competed ServeHttp on request")
}
// make shallow copy of request before chaning anything to avoid side effects
newReq := *req
stuck := false
if r.ss != nil {
cookie_url, present, err := r.ss.GetBackend(&newReq, r.Servers())
if r.stickySession != nil {
cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers())
if err != nil {
r.errHandler.ServeHTTP(w, req, err)
return
log.Infof("vulcand/oxy/roundrobin/rr: error using server from cookie: %v", err)
}
if present {
newReq.URL = cookie_url
newReq.URL = cookieURL
stuck = true
}
}
@ -95,11 +110,22 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
if r.ss != nil {
r.ss.StickBackend(url, &w)
if r.stickySession != nil {
r.stickySession.StickBackend(url, &w)
}
newReq.URL = url
}
if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
}
//Emit event to a listener if one exists
if r.requestRewriteListener != nil {
r.requestRewriteListener(req, &newReq)
}
r.next.ServeHTTP(w, &newReq)
}
@ -144,8 +170,6 @@ func (r *RoundRobin) nextServer() (*server, error) {
return srv, nil
}
}
// We did full circle and found no available servers
return nil, fmt.Errorf("no available servers")
}
func (r *RoundRobin) RemoveServer(u *url.URL) error {

View file

@ -7,16 +7,16 @@ import (
)
type StickySession struct {
cookiename string
cookieName string
}
func NewStickySession(c string) *StickySession {
return &StickySession{c}
func NewStickySession(cookieName string) *StickySession {
return &StickySession{cookieName}
}
// GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers.
func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.URL, bool, error) {
cookie, err := req.Cookie(s.cookiename)
cookie, err := req.Cookie(s.cookieName)
switch err {
case nil:
case http.ErrNoCookie:
@ -25,22 +25,21 @@ func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.
return nil, false, err
}
s_url, err := url.Parse(cookie.Value)
serverURL, err := url.Parse(cookie.Value)
if err != nil {
return nil, false, err
}
if s.isBackendAlive(s_url, servers) {
return s_url, true, nil
if s.isBackendAlive(serverURL, servers) {
return serverURL, true, nil
} else {
return nil, false, nil
}
}
func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) {
c := &http.Cookie{Name: s.cookiename, Value: backend.String(), Path: "/"}
http.SetCookie(*w, c)
return
cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"}
http.SetCookie(*w, cookie)
}
func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) bool {
@ -48,8 +47,8 @@ func (s *StickySession) isBackendAlive(needle *url.URL, haystack []*url.URL) boo
return false
}
for _, s := range haystack {
if sameURL(needle, s) {
for _, serverURL := range haystack {
if sameURL(needle, serverURL) {
return true
}
}

View file

@ -1,9 +1,21 @@
/*
package stream provides http.Handler middleware that solves several problems when dealing with http requests:
package stream provides http.Handler middleware that passes-through the entire request
Reads the entire request and response into buffer, optionally buffering it to disk for large requests.
Checks the limits for the requests and responses, rejecting in case if the limit was exceeded.
Changes request content-transfer-encoding from chunked and provides total size to the handlers.
Stream works around several limitations caused by buffering implementations, but
also introduces certain risks.
Workarounds for buffering limitations:
1. Streaming really large chunks of data (large file transfers, or streaming videos,
etc.)
2. Streaming (chunking) sparse data. For example, an implementation might
send a health check or a heart beat over a long-lived connection. This
does not play well with buffering.
Risks:
1. Connections could survive for very long periods of time.
2. There is no easy way to enforce limits on size/time of a connection.
Examples of a streaming middleware:
@ -12,340 +24,69 @@ Examples of a streaming middleware:
w.Write([]byte("hello"))
})
// Stream will read the body in buffer before passing the request to the handler
// calculate total size of the request and transform it from chunked encoding
// before passing to the server
// Stream will literally pass through to the next handler without ANY buffering
// or validation of the data.
stream.New(handler)
// This version will buffer up to 2MB in memory and will serialize any extra
// to a temporary file, if the request size exceeds 10MB it will reject the request
stream.New(handler,
stream.MemRequestBodyBytes(2 * 1024 * 1024),
stream.MaxRequestBodyBytes(10 * 1024 * 1024))
// Will do the same as above, but with responses
stream.New(handler,
stream.MemResponseBodyBytes(2 * 1024 * 1024),
stream.MaxResponseBodyBytes(10 * 1024 * 1024))
// Stream will replay the request if the handler returns error at least 3 times
// before returning the response
stream.New(handler, stream.Retry(`IsNetworkError() && Attempts() <= 2`))
*/
package stream
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/mailgun/multibuf"
log "github.com/Sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
const (
// Store up to 1MB in RAM
DefaultMemBodyBytes = 1048576
// No limit by default
DefaultMaxBodyBytes = -1
// Maximum retry attempts
DefaultMaxRetryAttempts = 10
)
var errHandler utils.ErrorHandler = &SizeErrHandler{}
// Streamer is responsible for streaming requests and responses
// It buffers large reqeuests and responses to disk,
type Streamer struct {
// Stream is responsible for buffering requests and responses
// It buffers large requests and responses to disk,
type Stream struct {
maxRequestBodyBytes int64
memRequestBodyBytes int64
maxResponseBodyBytes int64
memResponseBodyBytes int64
retryPredicate hpredicate
next http.Handler
errHandler utils.ErrorHandler
log utils.Logger
}
// New returns a new streamer middleware. New() function supports optional functional arguments
func New(next http.Handler, setters ...optSetter) (*Streamer, error) {
strm := &Streamer{
func New(next http.Handler, setters ...optSetter) (*Stream, error) {
strm := &Stream{
next: next,
maxRequestBodyBytes: DefaultMaxBodyBytes,
memRequestBodyBytes: DefaultMemBodyBytes,
maxResponseBodyBytes: DefaultMaxBodyBytes,
memResponseBodyBytes: DefaultMemBodyBytes,
}
for _, s := range setters {
if err := s(strm); err != nil {
return nil, err
}
}
if strm.errHandler == nil {
strm.errHandler = errHandler
}
if strm.log == nil {
strm.log = utils.NullLogger
}
return strm, nil
}
type optSetter func(s *Streamer) error
// Retry provides a predicate that allows stream middleware to replay the request
// if it matches certain condition, e.g. returns special error code. Available functions are:
//
// Attempts() - limits the amount of retry attempts
// ResponseCode() - returns http response code
// IsNetworkError() - tests if response code is related to networking error
//
// Example of the predicate:
//
// `Attempts() <= 2 && ResponseCode() == 502`
//
func Retry(predicate string) optSetter {
return func(s *Streamer) error {
p, err := parseExpression(predicate)
if err != nil {
return err
}
s.retryPredicate = p
return nil
}
}
// Logger sets the logger that will be used by this middleware.
func Logger(l utils.Logger) optSetter {
return func(s *Streamer) error {
s.log = l
return nil
}
}
// ErrorHandler sets error handler of the server
func ErrorHandler(h utils.ErrorHandler) optSetter {
return func(s *Streamer) error {
s.errHandler = h
return nil
}
}
// MaxRequestBodyBytes sets the maximum request body size in bytes
func MaxRequestBodyBytes(m int64) optSetter {
return func(s *Streamer) error {
if m < 0 {
return fmt.Errorf("max bytes should be >= 0 got %d", m)
}
s.maxRequestBodyBytes = m
return nil
}
}
// MaxRequestBody bytes sets the maximum request body to be stored in memory
// stream middleware will serialize the excess to disk.
func MemRequestBodyBytes(m int64) optSetter {
return func(s *Streamer) error {
if m < 0 {
return fmt.Errorf("mem bytes should be >= 0 got %d", m)
}
s.memRequestBodyBytes = m
return nil
}
}
// MaxResponseBodyBytes sets the maximum request body size in bytes
func MaxResponseBodyBytes(m int64) optSetter {
return func(s *Streamer) error {
if m < 0 {
return fmt.Errorf("max bytes should be >= 0 got %d", m)
}
s.maxResponseBodyBytes = m
return nil
}
}
// MemResponseBodyBytes sets the maximum request body to be stored in memory
// stream middleware will serialize the excess to disk.
func MemResponseBodyBytes(m int64) optSetter {
return func(s *Streamer) error {
if m < 0 {
return fmt.Errorf("mem bytes should be >= 0 got %d", m)
}
s.memResponseBodyBytes = m
return nil
}
}
type optSetter func(s *Stream) error
// Wrap sets the next handler to be called by stream handler.
func (s *Streamer) Wrap(next http.Handler) error {
func (s *Stream) Wrap(next http.Handler) error {
s.next = next
return nil
}
func (s *Streamer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if err := s.checkLimit(req); err != nil {
s.log.Infof("request body over limit: %v", err)
s.errHandler.ServeHTTP(w, req, err)
return
func (s *Stream) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/stream: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/stream: competed ServeHttp on request")
}
// Read the body while keeping limits in mind. This reader controls the maximum bytes
// to read into memory and disk. This reader returns an error if the total request size exceeds the
// prefefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1
// and the reader would be unbounded bufio in the http.Server
body, err := multibuf.New(req.Body, multibuf.MaxBytes(s.maxRequestBodyBytes), multibuf.MemBytes(s.memRequestBodyBytes))
if err != nil || body == nil {
s.errHandler.ServeHTTP(w, req, err)
return
}
// Set request body to buffered reader that can replay the read and execute Seek
// Note that we don't change the original request body as it's handled by the http server
// and we don'w want to mess with standard library
defer body.Close()
// We need to set ContentLength based on known request size. The incoming request may have been
// set without content length or using chunked TransferEncoding
totalSize, err := body.Size()
if err != nil {
s.log.Errorf("failed to get size, err %v", err)
s.errHandler.ServeHTTP(w, req, err)
return
}
outreq := s.copyRequest(req, body, totalSize)
attempt := 1
for {
// We create a special writer that will limit the response size, buffer it to disk if necessary
writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(s.maxResponseBodyBytes), multibuf.MemBytes(s.memResponseBodyBytes))
if err != nil {
s.errHandler.ServeHTTP(w, req, err)
return
}
// We are mimicking http.ResponseWriter to replace writer with our special writer
b := &bufferWriter{
header: make(http.Header),
buffer: writer,
}
defer b.Close()
s.next.ServeHTTP(b, outreq)
var reader multibuf.MultiReader
if b.expectBody(outreq) {
rdr, err := writer.Reader()
if err != nil {
s.log.Errorf("failed to read response, err %v", err)
s.errHandler.ServeHTTP(w, req, err)
return
}
defer rdr.Close()
reader = rdr
}
if (s.retryPredicate == nil || attempt > DefaultMaxRetryAttempts) ||
!s.retryPredicate(&context{r: req, attempt: attempt, responseCode: b.code, log: s.log}) {
utils.CopyHeaders(w.Header(), b.Header())
w.WriteHeader(b.code)
if reader != nil {
io.Copy(w, reader)
}
return
}
attempt += 1
if _, err := body.Seek(0, 0); err != nil {
s.log.Errorf("Failed to rewind: error: %v", err)
s.errHandler.ServeHTTP(w, req, err)
return
}
outreq = s.copyRequest(req, body, totalSize)
s.log.Infof("retry Request(%v %v) attempt %v", req.Method, req.URL, attempt)
}
}
func (s *Streamer) copyRequest(req *http.Request, body io.ReadCloser, bodySize int64) *http.Request {
o := *req
o.URL = utils.CopyURL(req.URL)
o.Header = make(http.Header)
utils.CopyHeaders(o.Header, req.Header)
o.ContentLength = bodySize
// remove TransferEncoding that could have been previously set because we have transformed the request from chunked encoding
o.TransferEncoding = []string{}
// http.Transport will close the request body on any error, we are controlling the close process ourselves, so we override the closer here
o.Body = ioutil.NopCloser(body)
return &o
}
func (s *Streamer) checkLimit(req *http.Request) error {
if s.maxRequestBodyBytes <= 0 {
return nil
}
if req.ContentLength > s.maxRequestBodyBytes {
return &multibuf.MaxSizeReachedError{MaxSize: s.maxRequestBodyBytes}
}
return nil
}
type bufferWriter struct {
header http.Header
code int
buffer multibuf.WriterOnce
}
// RFC2616 #4.4
func (b *bufferWriter) expectBody(r *http.Request) bool {
if r.Method == "HEAD" {
return false
}
if (b.code >= 100 && b.code < 200) || b.code == 204 || b.code == 304 {
return false
}
if b.header.Get("Content-Length") == "" && b.header.Get("Transfer-Encoding") == "" {
return false
}
if b.header.Get("Content-Length") == "0" {
return false
}
return true
}
func (b *bufferWriter) Close() error {
return b.buffer.Close()
}
func (b *bufferWriter) Header() http.Header {
return b.header
}
func (b *bufferWriter) Write(buf []byte) (int, error) {
return b.buffer.Write(buf)
}
// WriteHeader sets rw.Code.
func (b *bufferWriter) WriteHeader(code int) {
b.code = code
}
type SizeErrHandler struct {
}
func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if _, ok := err.(*multibuf.MaxSizeReachedError); ok {
w.WriteHeader(http.StatusRequestEntityTooLarge)
w.Write([]byte(http.StatusText(http.StatusRequestEntityTooLarge)))
return
}
utils.DefaultHandler.ServeHTTP(w, req, err)
s.next.ServeHTTP(w, req)
}

View file

@ -4,7 +4,6 @@ import (
"fmt"
"net/http"
"github.com/vulcand/oxy/utils"
"github.com/vulcand/predicate"
)
@ -17,7 +16,6 @@ type context struct {
r *http.Request
attempt int
responseCode int
log utils.Logger
}
type hpredicate func(*context) bool

View file

@ -22,18 +22,18 @@ func ParseAuthHeader(header string) (*BasicAuth, error) {
return nil, fmt.Errorf(fmt.Sprintf("Failed to parse header '%s'", header))
}
auth_type := strings.ToLower(values[0])
if auth_type != "basic" {
return nil, fmt.Errorf("Expected basic auth type, got '%s'", auth_type)
authType := strings.ToLower(values[0])
if authType != "basic" {
return nil, fmt.Errorf("Expected basic auth type, got '%s'", authType)
}
encoded_string := values[1]
decoded_string, err := base64.StdEncoding.DecodeString(encoded_string)
encodedString := values[1]
decodedString, err := base64.StdEncoding.DecodeString(encodedString)
if err != nil {
return nil, fmt.Errorf("Failed to parse header '%s', base64 failed: %s", header, err)
}
values = strings.SplitN(string(decoded_string), ":", 2)
values = strings.SplitN(string(decodedString), ":", 2)
if len(values) != 2 {
return nil, fmt.Errorf("Failed to parse header '%s', expected separator ':'", header)
}

60
vendor/github.com/vulcand/oxy/utils/dumpreq.go generated vendored Normal file
View file

@ -0,0 +1,60 @@
package utils
import (
"crypto/tls"
"encoding/json"
"fmt"
"mime/multipart"
"net/http"
"net/url"
)
type SerializableHttpRequest struct {
Method string
URL *url.URL
Proto string // "HTTP/1.0"
ProtoMajor int // 1
ProtoMinor int // 0
Header http.Header
ContentLength int64
TransferEncoding []string
Host string
Form url.Values
PostForm url.Values
MultipartForm *multipart.Form
Trailer http.Header
RemoteAddr string
RequestURI string
TLS *tls.ConnectionState
}
func Clone(r *http.Request) *SerializableHttpRequest {
if r == nil {
return nil
}
rc := new(SerializableHttpRequest)
rc.Method = r.Method
rc.URL = r.URL
rc.Proto = r.Proto
rc.ProtoMajor = r.ProtoMajor
rc.ProtoMinor = r.ProtoMinor
rc.Header = r.Header
rc.ContentLength = r.ContentLength
rc.Host = r.Host
rc.RemoteAddr = r.RemoteAddr
rc.RequestURI = r.RequestURI
return rc
}
func (s *SerializableHttpRequest) ToJson() string {
if jsonVal, err := json.Marshal(s); err != nil || jsonVal == nil {
return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err.Error())
} else {
return string(jsonVal)
}
}
func DumpHttpRequest(req *http.Request) string {
return fmt.Sprintf("%v", Clone(req).ToJson())
}

View file

@ -1,86 +0,0 @@
package utils
import (
"io"
"log"
)
var NullLogger Logger = &NOPLogger{}
// Logger defines a simple logging interface
type Logger interface {
Infof(format string, args ...interface{})
Warningf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
type FileLogger struct {
info *log.Logger
warn *log.Logger
error *log.Logger
}
func NewFileLogger(w io.Writer, lvl LogLevel) *FileLogger {
l := &FileLogger{}
flag := log.Ldate | log.Ltime | log.Lmicroseconds
if lvl <= INFO {
l.info = log.New(w, "INFO: ", flag)
}
if lvl <= WARN {
l.warn = log.New(w, "WARN: ", flag)
}
if lvl <= ERROR {
l.error = log.New(w, "ERR: ", flag)
}
return l
}
func (f *FileLogger) Infof(format string, args ...interface{}) {
if f.info == nil {
return
}
f.info.Printf(format, args...)
}
func (f *FileLogger) Warningf(format string, args ...interface{}) {
if f.warn == nil {
return
}
f.warn.Printf(format, args...)
}
func (f *FileLogger) Errorf(format string, args ...interface{}) {
if f.error == nil {
return
}
f.error.Printf(format, args...)
}
type NOPLogger struct {
}
func (*NOPLogger) Infof(format string, args ...interface{}) {
}
func (*NOPLogger) Warningf(format string, args ...interface{}) {
}
func (*NOPLogger) Errorf(format string, args ...interface{}) {
}
func (*NOPLogger) Info(string) {
}
func (*NOPLogger) Warning(string) {
}
func (*NOPLogger) Error(string) {
}
type LogLevel int
const (
INFO = iota
WARN
ERROR
)

View file

@ -2,19 +2,23 @@ package utils
import (
"bufio"
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
"reflect"
log "github.com/Sirupsen/logrus"
)
// ProxyWriter helps to capture response headers and status code
// from the ServeHTTP. It can be safely passed to ServeHTTP handler,
// wrapping the real response writer.
type ProxyWriter struct {
W http.ResponseWriter
Code int
W http.ResponseWriter
Code int
Length int64
}
func (p *ProxyWriter) StatusCode() int {
@ -31,6 +35,7 @@ func (p *ProxyWriter) Header() http.Header {
}
func (p *ProxyWriter) Write(buf []byte) (int, error) {
p.Length = p.Length + int64(len(buf))
return p.W.Write(buf)
}
@ -45,8 +50,20 @@ func (p *ProxyWriter) Flush() {
}
}
func (p *ProxyWriter) CloseNotify() <-chan bool {
if cn, ok := p.W.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.W))
return make(<-chan bool)
}
func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return p.W.(http.Hijacker).Hijack()
if hi, ok := p.W.(http.Hijacker); ok {
return hi.Hijack()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.W))
return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.W))
}
func NewBufferWriter(w io.WriteCloser) *BufferWriter {
@ -79,8 +96,20 @@ func (b *BufferWriter) WriteHeader(code int) {
b.Code = code
}
func (b *BufferWriter) CloseNotify() <-chan bool {
if cn, ok := b.W.(http.CloseNotifier); ok {
return cn.CloseNotify()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.W))
return make(<-chan bool)
}
func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return b.W.(http.Hijacker).Hijack()
if hi, ok := b.W.(http.Hijacker); ok {
return hi.Hijack()
}
log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.W))
return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.W))
}
type nopWriteCloser struct {
@ -106,11 +135,9 @@ func CopyURL(i *url.URL) *url.URL {
// CopyHeaders copies http headers from source to destination, it
// does not overide, but adds multiple headers
func CopyHeaders(dst, src http.Header) {
func CopyHeaders(dst http.Header, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
dst[k] = append(dst[k], vv...)
}
}
@ -130,9 +157,3 @@ func RemoveHeaders(headers http.Header, names ...string) {
headers.Del(h)
}
}
// Parse the MIME media type value of a header.
func GetHeaderMediaType(headers http.Header, name string) (string, error) {
mediatype, _, err := mime.ParseMediaType(headers.Get(name))
return mediatype, err
}