Send 'Retry-After' to comply with RFC6585.

This commit is contained in:
Ludovic Fernandez 2018-07-11 10:08:03 +02:00 committed by Traefiker Bot
parent 027093a5a5
commit 8d75aba7eb
29 changed files with 435 additions and 172 deletions

2
Gopkg.lock generated
View file

@ -1267,7 +1267,7 @@
"roundrobin", "roundrobin",
"utils" "utils"
] ]
revision = "adbef6bedf021985587c3c18c9d4b84b2d78f67c" revision = "f0cbb9d6b797d92d168b95b5c443a31dfa67ccd0"
[[projects]] [[projects]]
name = "github.com/vulcand/predicate" name = "github.com/vulcand/predicate"

View file

@ -36,13 +36,12 @@ Examples of a buffering middleware:
package buffer package buffer
import ( import (
"bufio"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"bufio"
"net" "net"
"net/http"
"reflect" "reflect"
"github.com/mailgun/multibuf" "github.com/mailgun/multibuf"
@ -74,6 +73,8 @@ type Buffer struct {
next http.Handler next http.Handler
errHandler utils.ErrorHandler errHandler utils.ErrorHandler
log *log.Logger
} }
// New returns a new buffer middleware. New() function supports optional functional arguments // New returns a new buffer middleware. New() function supports optional functional arguments
@ -86,6 +87,8 @@ func New(next http.Handler, setters ...optSetter) (*Buffer, error) {
maxResponseBodyBytes: DefaultMaxBodyBytes, maxResponseBodyBytes: DefaultMaxBodyBytes,
memResponseBodyBytes: DefaultMemBodyBytes, memResponseBodyBytes: DefaultMemBodyBytes,
log: log.StandardLogger(),
} }
for _, s := range setters { for _, s := range setters {
if err := s(strm); err != nil { if err := s(strm); err != nil {
@ -99,6 +102,16 @@ func New(next http.Handler, setters ...optSetter) (*Buffer, error) {
return strm, nil return strm, nil
} }
// Logger defines the logger the buffer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) optSetter {
return func(b *Buffer) error {
b.log = l
return nil
}
}
type optSetter func(b *Buffer) error type optSetter func(b *Buffer) error
// CondSetter Conditional setter. // CondSetter Conditional setter.
@ -154,7 +167,7 @@ func MaxRequestBodyBytes(m int64) optSetter {
} }
} }
// MaxRequestBody bytes sets the maximum request body to be stored in memory // MemRequestBodyBytes bytes sets the maximum request body to be stored in memory
// buffer middleware will serialize the excess to disk. // buffer middleware will serialize the excess to disk.
func MemRequestBodyBytes(m int64) optSetter { func MemRequestBodyBytes(m int64) optSetter {
return func(b *Buffer) error { return func(b *Buffer) error {
@ -196,8 +209,8 @@ func (b *Buffer) Wrap(next http.Handler) error {
} }
func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel { if b.log.Level >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry := b.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/buffer: begin ServeHttp on request") logEntry.Debug("vulcand/oxy/buffer: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/buffer: completed ServeHttp on request") defer logEntry.Debug("vulcand/oxy/buffer: completed ServeHttp on request")
} }
@ -210,11 +223,11 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Read the body while keeping limits in mind. This reader controls the maximum bytes // Read the body while keeping limits in mind. This reader controls the maximum bytes
// to read into memory and disk. This reader returns an error if the total request size exceeds the // to read into memory and disk. This reader returns an error if the total request size exceeds the
// prefefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1 // predefined MaxSizeBytes. This can occur if we got chunked request, in this case ContentLength would be set to -1
// and the reader would be unbounded bufio in the http.Server // and the reader would be unbounded bufio in the http.Server
body, err := multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes)) body, err := multibuf.New(req.Body, multibuf.MaxBytes(b.maxRequestBodyBytes), multibuf.MemBytes(b.memRequestBodyBytes))
if err != nil || body == nil { if err != nil || body == nil {
log.Errorf("vulcand/oxy/buffer: error when reading request body, err: %v", err) b.log.Errorf("vulcand/oxy/buffer: error when reading request body, err: %v", err)
b.errHandler.ServeHTTP(w, req, err) b.errHandler.ServeHTTP(w, req, err)
return return
} }
@ -235,7 +248,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// set without content length or using chunked TransferEncoding // set without content length or using chunked TransferEncoding
totalSize, err := body.Size() totalSize, err := body.Size()
if err != nil { if err != nil {
log.Errorf("vulcand/oxy/buffer: failed to get request size, err: %v", err) b.log.Errorf("vulcand/oxy/buffer: failed to get request size, err: %v", err)
b.errHandler.ServeHTTP(w, req, err) b.errHandler.ServeHTTP(w, req, err)
return return
} }
@ -251,7 +264,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// We create a special writer that will limit the response size, buffer it to disk if necessary // We create a special writer that will limit the response size, buffer it to disk if necessary
writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(b.maxResponseBodyBytes), multibuf.MemBytes(b.memResponseBodyBytes)) writer, err := multibuf.NewWriterOnce(multibuf.MaxBytes(b.maxResponseBodyBytes), multibuf.MemBytes(b.memResponseBodyBytes))
if err != nil { if err != nil {
log.Errorf("vulcand/oxy/buffer: failed create response writer, err: %v", err) b.log.Errorf("vulcand/oxy/buffer: failed create response writer, err: %v", err)
b.errHandler.ServeHTTP(w, req, err) b.errHandler.ServeHTTP(w, req, err)
return return
} }
@ -261,12 +274,13 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
header: make(http.Header), header: make(http.Header),
buffer: writer, buffer: writer,
responseWriter: w, responseWriter: w,
log: b.log,
} }
defer bw.Close() defer bw.Close()
b.next.ServeHTTP(bw, outreq) b.next.ServeHTTP(bw, outreq)
if bw.hijacked { if bw.hijacked {
log.Debugf("vulcand/oxy/buffer: connection was hijacked downstream. Not taking any action in buffer.") b.log.Debugf("vulcand/oxy/buffer: connection was hijacked downstream. Not taking any action in buffer.")
return return
} }
@ -274,7 +288,7 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if bw.expectBody(outreq) { if bw.expectBody(outreq) {
rdr, err := writer.Reader() rdr, err := writer.Reader()
if err != nil { if err != nil {
log.Errorf("vulcand/oxy/buffer: failed to read response, err: %v", err) b.log.Errorf("vulcand/oxy/buffer: failed to read response, err: %v", err)
b.errHandler.ServeHTTP(w, req, err) b.errHandler.ServeHTTP(w, req, err)
return return
} }
@ -292,17 +306,17 @@ func (b *Buffer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return return
} }
attempt += 1 attempt++
if body != nil { if body != nil {
if _, err := body.Seek(0, 0); err != nil { if _, err := body.Seek(0, 0); err != nil {
log.Errorf("vulcand/oxy/buffer: failed to rewind response body, err: %v", err) b.log.Errorf("vulcand/oxy/buffer: failed to rewind response body, err: %v", err)
b.errHandler.ServeHTTP(w, req, err) b.errHandler.ServeHTTP(w, req, err)
return return
} }
} }
outreq = b.copyRequest(req, body, totalSize) outreq = b.copyRequest(req, body, totalSize)
log.Debugf("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt) b.log.Debugf("vulcand/oxy/buffer: retry Request(%v %v) attempt %v", req.Method, req.URL, attempt)
} }
} }
@ -339,6 +353,7 @@ type bufferWriter struct {
buffer multibuf.WriterOnce buffer multibuf.WriterOnce
responseWriter http.ResponseWriter responseWriter http.ResponseWriter
hijacked bool hijacked bool
log *log.Logger
} }
// RFC2616 #4.4 // RFC2616 #4.4
@ -376,16 +391,16 @@ func (b *bufferWriter) WriteHeader(code int) {
b.code = code b.code = code
} }
//CloseNotifier interface - this allows downstream connections to be terminated when the client terminates. // CloseNotifier interface - this allows downstream connections to be terminated when the client terminates.
func (b *bufferWriter) CloseNotify() <-chan bool { func (b *bufferWriter) CloseNotify() <-chan bool {
if cn, ok := b.responseWriter.(http.CloseNotifier); ok { if cn, ok := b.responseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify() return cn.CloseNotify()
} }
log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(b.responseWriter))
return make(<-chan bool) return make(<-chan bool)
} }
//This allows connections to be hijacked for websockets for instance. // Hijack This allows connections to be hijacked for websockets for instance.
func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hi, ok := b.responseWriter.(http.Hijacker); ok { if hi, ok := b.responseWriter.(http.Hijacker); ok {
conn, rw, err := hi.Hijack() conn, rw, err := hi.Hijack()
@ -394,12 +409,12 @@ func (b *bufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
} }
return conn, rw, err return conn, rw, err
} }
log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.responseWriter)) b.log.Warningf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(b.responseWriter))
return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.responseWriter)) return nil, nil, fmt.Errorf("The response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(b.responseWriter))
} }
type SizeErrHandler struct { // SizeErrHandler Size error handler
} type SizeErrHandler struct{}
func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { func (e *SizeErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if _, ok := err.(*multibuf.MaxSizeReachedError); ok { if _, ok := err.(*multibuf.MaxSizeReachedError); ok {

View file

@ -7,6 +7,7 @@ import (
"github.com/vulcand/predicate" "github.com/vulcand/predicate"
) )
// IsValidExpression check if it's a valid expression
func IsValidExpression(expr string) bool { func IsValidExpression(expr string) bool {
_, err := parseExpression(expr) _, err := parseExpression(expr)
return err == nil return err == nil

View file

@ -3,7 +3,7 @@
// Vulcan circuit breaker watches the error condtion to match // Vulcan circuit breaker watches the error condtion to match
// after which it activates the fallback scenario, e.g. returns the response code // after which it activates the fallback scenario, e.g. returns the response code
// or redirects the request to another location // or redirects the request to another location
//
// Circuit breakers start in the Standby state first, observing responses and watching location metrics. // Circuit breakers start in the Standby state first, observing responses and watching location metrics.
// //
// Once the Circuit breaker condition is met, it enters the "Tripped" state, where it activates fallback scenario // Once the Circuit breaker condition is met, it enters the "Tripped" state, where it activates fallback scenario
@ -31,9 +31,8 @@ import (
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/mailgun/timetools" "github.com/mailgun/timetools"
log "github.com/sirupsen/logrus"
"github.com/vulcand/oxy/memmetrics" "github.com/vulcand/oxy/memmetrics"
"github.com/vulcand/oxy/utils" "github.com/vulcand/oxy/utils"
) )
@ -63,6 +62,8 @@ type CircuitBreaker struct {
next http.Handler next http.Handler
clock timetools.TimeProvider clock timetools.TimeProvider
log *log.Logger
} }
// New creates a new CircuitBreaker middleware // New creates a new CircuitBreaker middleware
@ -76,6 +77,7 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption)
fallbackDuration: defaultFallbackDuration, fallbackDuration: defaultFallbackDuration,
recoveryDuration: defaultRecoveryDuration, recoveryDuration: defaultRecoveryDuration,
fallback: defaultFallback, fallback: defaultFallback,
log: log.StandardLogger(),
} }
for _, s := range options { for _, s := range options {
@ -99,9 +101,19 @@ func New(next http.Handler, expression string, options ...CircuitBreakerOption)
return cb, nil return cb, nil
} }
// Logger defines the logger the circuit breaker will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) CircuitBreakerOption {
return func(c *CircuitBreaker) error {
c.log = l
return nil
}
}
func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel { if c.log.Level >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry := c.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request") logEntry.Debug("vulcand/oxy/circuitbreaker: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/circuitbreaker: completed ServeHttp on request") defer logEntry.Debug("vulcand/oxy/circuitbreaker: completed ServeHttp on request")
} }
@ -112,6 +124,7 @@ func (c *CircuitBreaker) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c.serve(w, req) c.serve(w, req)
} }
// Wrap sets the next handler to be called by circuit breaker handler.
func (c *CircuitBreaker) Wrap(next http.Handler) { func (c *CircuitBreaker) Wrap(next http.Handler) {
c.next = next c.next = next
} }
@ -126,7 +139,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque
c.m.Lock() c.m.Lock()
defer c.m.Unlock() defer c.m.Unlock()
log.Warnf("%v is in error state", c) c.log.Warnf("%v is in error state", c)
switch c.state { switch c.state {
case stateStandby: case stateStandby:
@ -156,7 +169,7 @@ func (c *CircuitBreaker) activateFallback(w http.ResponseWriter, req *http.Reque
func (c *CircuitBreaker) serve(w http.ResponseWriter, req *http.Request) { func (c *CircuitBreaker) serve(w http.ResponseWriter, req *http.Request) {
start := c.clock.UtcNow() start := c.clock.UtcNow()
p := utils.NewProxyWriter(w) p := utils.NewProxyWriterWithLogger(w, c.log)
c.next.ServeHTTP(p, req) c.next.ServeHTTP(p, req)
@ -191,13 +204,13 @@ func (c *CircuitBreaker) exec(s SideEffect) {
} }
go func() { go func() {
if err := s.Exec(); err != nil { if err := s.Exec(); err != nil {
log.Errorf("%v side effect failure: %v", c, err) c.log.Errorf("%v side effect failure: %v", c, err)
} }
}() }()
} }
func (c *CircuitBreaker) setState(new cbState, until time.Time) { func (c *CircuitBreaker) setState(new cbState, until time.Time) {
log.Debugf("%v setting state to %v, until %v", c, new, until) c.log.Debugf("%v setting state to %v, until %v", c, new, until)
c.state = new c.state = new
c.until = until c.until = until
switch new { switch new {
@ -230,7 +243,7 @@ func (c *CircuitBreaker) checkAndSet() {
c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod) c.lastCheck = c.clock.UtcNow().Add(c.checkPeriod)
if c.state == stateTripped { if c.state == stateTripped {
log.Debugf("%v skip set tripped", c) c.log.Debugf("%v skip set tripped", c)
return return
} }
@ -244,7 +257,7 @@ func (c *CircuitBreaker) checkAndSet() {
func (c *CircuitBreaker) setRecovering() { func (c *CircuitBreaker) setRecovering() {
c.setState(stateRecovering, c.clock.UtcNow().Add(c.recoveryDuration)) c.setState(stateRecovering, c.clock.UtcNow().Add(c.recoveryDuration))
c.rc = newRatioController(c.clock, c.recoveryDuration) c.rc = newRatioController(c.clock, c.recoveryDuration, c.log)
} }
// CircuitBreakerOption represents an option you can pass to New. // CircuitBreakerOption represents an option you can pass to New.
@ -296,7 +309,7 @@ func OnTripped(s SideEffect) CircuitBreakerOption {
} }
} }
// OnTripped sets a SideEffect to run when entering the Standby state. // OnStandby sets a SideEffect to run when entering the Standby state.
// Only one SideEffect can be set for this hook. // Only one SideEffect can be set for this hook.
func OnStandby(s SideEffect) CircuitBreakerOption { func OnStandby(s SideEffect) CircuitBreakerOption {
return func(c *CircuitBreaker) error { return func(c *CircuitBreaker) error {
@ -346,8 +359,7 @@ const (
var defaultFallback = &fallback{} var defaultFallback = &fallback{}
type fallback struct { type fallback struct{}
}
func (f *fallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (f *fallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable) w.WriteHeader(http.StatusServiceUnavailable)

View file

@ -13,10 +13,12 @@ import (
"github.com/vulcand/oxy/utils" "github.com/vulcand/oxy/utils"
) )
// SideEffect a side effect
type SideEffect interface { type SideEffect interface {
Exec() error Exec() error
} }
// Webhook Web hook
type Webhook struct { type Webhook struct {
URL string URL string
Method string Method string
@ -25,11 +27,15 @@ type Webhook struct {
Body []byte Body []byte
} }
// WebhookSideEffect a web hook side effect
type WebhookSideEffect struct { type WebhookSideEffect struct {
w Webhook w Webhook
log *log.Logger
} }
func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) { // NewWebhookSideEffectsWithLogger creates a new WebhookSideEffect
func NewWebhookSideEffectsWithLogger(w Webhook, l *log.Logger) (*WebhookSideEffect, error) {
if w.Method == "" { if w.Method == "" {
return nil, fmt.Errorf("Supply method") return nil, fmt.Errorf("Supply method")
} }
@ -38,7 +44,12 @@ func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) {
return nil, err return nil, err
} }
return &WebhookSideEffect{w: w}, nil return &WebhookSideEffect{w: w, log: l}, nil
}
// NewWebhookSideEffect creates a new WebhookSideEffect
func NewWebhookSideEffect(w Webhook) (*WebhookSideEffect, error) {
return NewWebhookSideEffectsWithLogger(w, log.StandardLogger())
} }
func (w *WebhookSideEffect) getBody() io.Reader { func (w *WebhookSideEffect) getBody() io.Reader {
@ -51,6 +62,7 @@ func (w *WebhookSideEffect) getBody() io.Reader {
return nil return nil
} }
// Exec execute the side effect
func (w *WebhookSideEffect) Exec() error { func (w *WebhookSideEffect) Exec() error {
r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody()) r, err := http.NewRequest(w.w.Method, w.w.URL, w.getBody())
if err != nil { if err != nil {
@ -73,6 +85,6 @@ func (w *WebhookSideEffect) Exec() error {
if err != nil { if err != nil {
return err return err
} }
log.Debugf("%v got response: (%s): %s", w, re.Status, string(body)) w.log.Debugf("%v got response: (%s): %s", w, re.Status, string(body))
return nil return nil
} }

View file

@ -10,26 +10,36 @@ import (
"github.com/vulcand/oxy/utils" "github.com/vulcand/oxy/utils"
) )
// Response response model
type Response struct { type Response struct {
StatusCode int StatusCode int
ContentType string ContentType string
Body []byte Body []byte
} }
// ResponseFallback fallback response handler
type ResponseFallback struct { type ResponseFallback struct {
r Response r Response
log *log.Logger
} }
func NewResponseFallback(r Response) (*ResponseFallback, error) { // NewResponseFallbackWithLogger creates a new ResponseFallback
func NewResponseFallbackWithLogger(r Response, l *log.Logger) (*ResponseFallback, error) {
if r.StatusCode == 0 { if r.StatusCode == 0 {
return nil, fmt.Errorf("response code should not be 0") return nil, fmt.Errorf("response code should not be 0")
} }
return &ResponseFallback{r: r}, nil return &ResponseFallback{r: r, log: l}, nil
}
// NewResponseFallback creates a new ResponseFallback
func NewResponseFallback(r Response) (*ResponseFallback, error) {
return NewResponseFallbackWithLogger(r, log.StandardLogger())
} }
func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel { if f.log.Level >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request") logEntry.Debug("vulcand/oxy/fallback/response: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/fallback/response: completed ServeHttp on request") defer logEntry.Debug("vulcand/oxy/fallback/response: completed ServeHttp on request")
} }
@ -45,27 +55,38 @@ func (f *ResponseFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
} }
// Redirect redirect model
type Redirect struct { type Redirect struct {
URL string URL string
PreservePath bool PreservePath bool
} }
// RedirectFallback fallback redirect handler
type RedirectFallback struct { type RedirectFallback struct {
u *url.URL
r Redirect r Redirect
u *url.URL
log *log.Logger
} }
func NewRedirectFallback(r Redirect) (*RedirectFallback, error) { // NewRedirectFallbackWithLogger creates a new RedirectFallback
func NewRedirectFallbackWithLogger(r Redirect, l *log.Logger) (*RedirectFallback, error) {
u, err := url.ParseRequestURI(r.URL) u, err := url.ParseRequestURI(r.URL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &RedirectFallback{u: u, r: r}, nil return &RedirectFallback{r: r, u: u, log: l}, nil
}
// NewRedirectFallback creates a new RedirectFallback
func NewRedirectFallback(r Redirect) (*RedirectFallback, error) {
return NewRedirectFallbackWithLogger(r, log.StandardLogger())
} }
func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (f *RedirectFallback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel { if f.log.Level >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request") logEntry.Debug("vulcand/oxy/fallback/redirect: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/fallback/redirect: completed ServeHttp on request") defer logEntry.Debug("vulcand/oxy/fallback/redirect: completed ServeHttp on request")
} }

View file

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/vulcand/predicate" "github.com/vulcand/predicate"
) )
@ -50,7 +49,7 @@ func latencyAtQuantile(quantile float64) toInt {
return func(c *CircuitBreaker) int { return func(c *CircuitBreaker) int {
h, err := c.metrics.LatencyHistogram() h, err := c.metrics.LatencyHistogram()
if err != nil { if err != nil {
log.Errorf("Failed to get latency histogram, for %v error: %v", c, err) c.log.Errorf("Failed to get latency histogram, for %v error: %v", c, err)
return 0 return 0
} }
return int(h.LatencyAtQuantile(quantile) / time.Millisecond) return int(h.LatencyAtQuantile(quantile) / time.Millisecond)

View file

@ -19,13 +19,17 @@ type ratioController struct {
tm timetools.TimeProvider tm timetools.TimeProvider
allowed int allowed int
denied int denied int
log *log.Logger
} }
func newRatioController(tm timetools.TimeProvider, rampUp time.Duration) *ratioController { func newRatioController(tm timetools.TimeProvider, rampUp time.Duration, log *log.Logger) *ratioController {
return &ratioController{ return &ratioController{
duration: rampUp, duration: rampUp,
tm: tm, tm: tm,
start: tm.UtcNow(), start: tm.UtcNow(),
log: log,
} }
} }
@ -34,17 +38,17 @@ func (r *ratioController) String() string {
} }
func (r *ratioController) allowRequest() bool { func (r *ratioController) allowRequest() bool {
log.Debugf("%v", r) r.log.Debugf("%v", r)
t := r.targetRatio() t := r.targetRatio()
// This condition answers the question - would we satisfy the target ratio if we allow this request? // This condition answers the question - would we satisfy the target ratio if we allow this request?
e := r.computeRatio(r.allowed+1, r.denied) e := r.computeRatio(r.allowed+1, r.denied)
if e < t { if e < t {
r.allowed++ r.allowed++
log.Debugf("%v allowed", r) r.log.Debugf("%v allowed", r)
return true return true
} }
r.denied++ r.denied++
log.Debugf("%v denied", r) r.log.Debugf("%v denied", r)
return false return false
} }

View file

@ -1,4 +1,4 @@
// package connlimit provides control over simultaneous connections coming from the same source // Package connlimit provides control over simultaneous connections coming from the same source
package connlimit package connlimit
import ( import (
@ -10,7 +10,7 @@ import (
"github.com/vulcand/oxy/utils" "github.com/vulcand/oxy/utils"
) )
// Limiter tracks concurrent connection per token // ConnLimiter tracks concurrent connection per token
// and is capable of rejecting connections if they are failed // and is capable of rejecting connections if they are failed
type ConnLimiter struct { type ConnLimiter struct {
mutex *sync.Mutex mutex *sync.Mutex
@ -21,8 +21,10 @@ type ConnLimiter struct {
next http.Handler next http.Handler
errHandler utils.ErrorHandler errHandler utils.ErrorHandler
log *log.Logger
} }
// New creates a new ConnLimiter
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) { func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) {
if extract == nil { if extract == nil {
return nil, fmt.Errorf("Extract function can not be nil") return nil, fmt.Errorf("Extract function can not be nil")
@ -33,6 +35,7 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64,
maxConnections: maxConnections, maxConnections: maxConnections,
connections: make(map[string]int64), connections: make(map[string]int64),
next: next, next: next,
log: log.StandardLogger(),
} }
for _, o := range options { for _, o := range options {
@ -41,11 +44,24 @@ func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64,
} }
} }
if cl.errHandler == nil { if cl.errHandler == nil {
cl.errHandler = defaultErrHandler cl.errHandler = &ConnErrHandler{
log: cl.log,
}
} }
return cl, nil return cl, nil
} }
// Logger defines the logger the connection limiter will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) ConnLimitOption {
return func(cl *ConnLimiter) error {
cl.log = l
return nil
}
}
// Wrap sets the next handler to be called by connexion limiter handler.
func (cl *ConnLimiter) Wrap(h http.Handler) { func (cl *ConnLimiter) Wrap(h http.Handler) {
cl.next = h cl.next = h
} }
@ -53,12 +69,12 @@ func (cl *ConnLimiter) Wrap(h http.Handler) {
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
token, amount, err := cl.extract.Extract(r) token, amount, err := cl.extract.Extract(r)
if err != nil { if err != nil {
log.Errorf("failed to extract source of the connection: %v", err) cl.log.Errorf("failed to extract source of the connection: %v", err)
cl.errHandler.ServeHTTP(w, r, err) cl.errHandler.ServeHTTP(w, r, err)
return return
} }
if err := cl.acquire(token, amount); err != nil { if err := cl.acquire(token, amount); err != nil {
log.Debugf("limiting request source %s: %v", token, err) cl.log.Debugf("limiting request source %s: %v", token, err)
cl.errHandler.ServeHTTP(w, r, err) cl.errHandler.ServeHTTP(w, r, err)
return return
} }
@ -95,6 +111,7 @@ func (cl *ConnLimiter) release(token string, amount int64) {
} }
} }
// MaxConnError maximum connections reached error
type MaxConnError struct { type MaxConnError struct {
max int64 max int64
} }
@ -103,12 +120,14 @@ func (m *MaxConnError) Error() string {
return fmt.Sprintf("max connections reached: %d", m.max) return fmt.Sprintf("max connections reached: %d", m.max)
} }
// ConnErrHandler connection limiter error handler
type ConnErrHandler struct { type ConnErrHandler struct {
log *log.Logger
} }
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if log.GetLevel() >= log.DebugLevel { if e.log.Level >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry := e.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request") logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/connlimit: completed ServeHttp on request") defer logEntry.Debug("vulcand/oxy/connlimit: completed ServeHttp on request")
} }
@ -121,6 +140,7 @@ func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err
utils.DefaultHandler.ServeHTTP(w, req, err) utils.DefaultHandler.ServeHTTP(w, req, err)
} }
// ConnLimitOption connection limit option type
type ConnLimitOption func(l *ConnLimiter) error type ConnLimitOption func(l *ConnLimiter) error
// ErrorHandler sets error handler of the server // ErrorHandler sets error handler of the server
@ -130,5 +150,3 @@ func ErrorHandler(h utils.ErrorHandler) ConnLimitOption {
return nil return nil
} }
} }
var defaultErrHandler = &ConnErrHandler{}

View file

@ -1,4 +1,4 @@
// package forwarder implements http handler that forwards requests to remote server // Package forward implements http handler that forwards requests to remote server
// and serves back the response // and serves back the response
// websocket proxying support based on https://github.com/yhat/wsutil // websocket proxying support based on https://github.com/yhat/wsutil
package forward package forward
@ -21,7 +21,7 @@ import (
"github.com/vulcand/oxy/utils" "github.com/vulcand/oxy/utils"
) )
// Oxy Logger interface of the internal // OxyLogger interface of the internal
type OxyLogger interface { type OxyLogger interface {
log.FieldLogger log.FieldLogger
GetLevel() log.Level GetLevel() log.Level
@ -42,8 +42,7 @@ type ReqRewriter interface {
type optSetter func(f *Forwarder) error type optSetter func(f *Forwarder) error
// PassHostHeader specifies if a client's Host header field should // PassHostHeader specifies if a client's Host header field should be delegated
// be delegated
func PassHostHeader(b bool) optSetter { func PassHostHeader(b bool) optSetter {
return func(f *Forwarder) error { return func(f *Forwarder) error {
f.httpForwarder.passHost = b f.httpForwarder.passHost = b
@ -68,8 +67,7 @@ func Rewriter(r ReqRewriter) optSetter {
} }
} }
// PassHostHeader specifies if a client's Host header field should // WebsocketTLSClientConfig define the websocker client TLS configuration
// be delegated
func WebsocketTLSClientConfig(tcc *tls.Config) optSetter { func WebsocketTLSClientConfig(tcc *tls.Config) optSetter {
return func(f *Forwarder) error { return func(f *Forwarder) error {
f.httpForwarder.tlsClientConfig = tcc f.httpForwarder.tlsClientConfig = tcc
@ -120,6 +118,7 @@ func Logger(l log.FieldLogger) optSetter {
} }
} }
// StateListener defines a state listener for the HTTP forwarder
func StateListener(stateListener UrlForwardingStateListener) optSetter { func StateListener(stateListener UrlForwardingStateListener) optSetter {
return func(f *Forwarder) error { return func(f *Forwarder) error {
f.stateListener = stateListener f.stateListener = stateListener
@ -127,6 +126,7 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter {
} }
} }
// ResponseModifier defines a response modifier for the HTTP forwarder
func ResponseModifier(responseModifier func(*http.Response) error) optSetter { func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error { return func(f *Forwarder) error {
f.httpForwarder.modifyResponse = responseModifier f.httpForwarder.modifyResponse = responseModifier
@ -134,6 +134,7 @@ func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
} }
} }
// StreamingFlushInterval defines a streaming flush interval for the HTTP forwarder
func StreamingFlushInterval(flushInterval time.Duration) optSetter { func StreamingFlushInterval(flushInterval time.Duration) optSetter {
return func(f *Forwarder) error { return func(f *Forwarder) error {
f.httpForwarder.flushInterval = flushInterval f.httpForwarder.flushInterval = flushInterval
@ -141,11 +142,13 @@ func StreamingFlushInterval(flushInterval time.Duration) optSetter {
} }
} }
// ErrorHandlingRoundTripper a error handling round tripper
type ErrorHandlingRoundTripper struct { type ErrorHandlingRoundTripper struct {
http.RoundTripper http.RoundTripper
errorHandler utils.ErrorHandler errorHandler utils.ErrorHandler
} }
// RoundTrip executes the round trip
func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTripper.RoundTrip(req) res, err := rt.RoundTripper.RoundTrip(req)
if err != nil { if err != nil {
@ -188,12 +191,15 @@ type httpForwarder struct {
bufferPool httputil.BufferPool bufferPool httputil.BufferPool
} }
const defaultFlushInterval = time.Duration(100) * time.Millisecond
// Connection states
const ( const (
defaultFlushInterval = time.Duration(100) * time.Millisecond StateConnected = iota
StateConnected = iota
StateDisconnected StateDisconnected
) )
// UrlForwardingStateListener URL forwarding state listener
type UrlForwardingStateListener func(*url.URL, int) type UrlForwardingStateListener func(*url.URL, int)
// New creates an instance of Forwarder based on the provided list of configuration options // New creates an instance of Forwarder based on the provided list of configuration options
@ -501,7 +507,7 @@ func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ct
} }
} }
// isWebsocketRequest determines if the specified HTTP request is a // IsWebsocketRequest determines if the specified HTTP request is a
// websocket handshake request // websocket handshake request
func IsWebsocketRequest(req *http.Request) bool { func IsWebsocketRequest(req *http.Request) bool {
containsHeader := func(name, value string) bool { containsHeader := func(name, value string) bool {

View file

@ -1,5 +1,6 @@
package forward package forward
// Headers
const ( const (
XForwardedProto = "X-Forwarded-Proto" XForwardedProto = "X-Forwarded-Proto"
XForwardedFor = "X-Forwarded-For" XForwardedFor = "X-Forwarded-For"
@ -22,7 +23,7 @@ const (
SecWebsocketAccept = "Sec-Websocket-Accept" SecWebsocketAccept = "Sec-Websocket-Accept"
) )
// Hop-by-hop headers. These are removed when sent to the backend. // HopHeaders Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
// Copied from reverseproxy.go, too bad // Copied from reverseproxy.go, too bad
var HopHeaders = []string{ var HopHeaders = []string{
@ -36,6 +37,7 @@ var HopHeaders = []string{
Upgrade, Upgrade,
} }
// WebsocketDialHeaders Websocket dial headers
var WebsocketDialHeaders = []string{ var WebsocketDialHeaders = []string{
Upgrade, Upgrade,
Connection, Connection,
@ -45,6 +47,7 @@ var WebsocketDialHeaders = []string{
SecWebsocketAccept, SecWebsocketAccept,
} }
// WebsocketUpgradeHeaders Websocket upgrade headers
var WebsocketUpgradeHeaders = []string{ var WebsocketUpgradeHeaders = []string{
Upgrade, Upgrade,
Connection, Connection,
@ -52,6 +55,7 @@ var WebsocketUpgradeHeaders = []string{
SecWebsocketExtensions, SecWebsocketExtensions,
} }
// XHeaders X-* headers
var XHeaders = []string{ var XHeaders = []string{
XForwardedProto, XForwardedProto,
XForwardedFor, XForwardedFor,

View file

@ -8,7 +8,7 @@ import (
"github.com/vulcand/oxy/utils" "github.com/vulcand/oxy/utils"
) )
// Rewriter is responsible for removing hop-by-hop headers and setting forwarding headers // HeaderRewriter is responsible for removing hop-by-hop headers and setting forwarding headers
type HeaderRewriter struct { type HeaderRewriter struct {
TrustForwardHeader bool TrustForwardHeader bool
Hostname string Hostname string
@ -19,6 +19,7 @@ func ipv6fix(clientIP string) string {
return strings.Split(clientIP, "%")[0] return strings.Split(clientIP, "%")[0]
} }
// Rewrite rewrite request headers
func (rw *HeaderRewriter) Rewrite(req *http.Request) { func (rw *HeaderRewriter) Rewrite(req *http.Request) {
if !rw.TrustForwardHeader { if !rw.TrustForwardHeader {
utils.RemoveHeaders(req.Header, XHeaders...) utils.RemoveHeaders(req.Header, XHeaders...)

View file

@ -6,7 +6,7 @@ import (
"time" "time"
) )
// SplitRatios provides simple anomaly detection for requests latencies. // SplitLatencies provides simple anomaly detection for requests latencies.
// it splits values into good or bad category based on the threshold and the median value. // it splits values into good or bad category based on the threshold and the median value.
// If all values are not far from the median, it will return all values in 'good' set. // If all values are not far from the median, it will return all values in 'good' set.
// Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored. // Precision is the smallest value to consider, e.g. if set to millisecond, microseconds will be ignored.
@ -23,10 +23,10 @@ func SplitLatencies(values []time.Duration, precision time.Duration) (good map[t
good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool) good, bad = make(map[time.Duration]bool), make(map[time.Duration]bool)
// Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise. // Note that multiplier makes this function way less sensitive than ratios detector, this is to avoid noise.
vgood, vbad := SplitFloat64(2, 0, ratios) vgood, vbad := SplitFloat64(2, 0, ratios)
for r, _ := range vgood { for r := range vgood {
good[v2r[r]] = true good[v2r[r]] = true
} }
for r, _ := range vbad { for r := range vbad {
bad[v2r[r]] = true bad[v2r[r]] = true
} }
return good, bad return good, bad

View file

@ -9,6 +9,7 @@ import (
type rcOptSetter func(*RollingCounter) error type rcOptSetter func(*RollingCounter) error
// CounterClock defines a counter clock
func CounterClock(c timetools.TimeProvider) rcOptSetter { func CounterClock(c timetools.TimeProvider) rcOptSetter {
return func(r *RollingCounter) error { return func(r *RollingCounter) error {
r.clock = c r.clock = c
@ -16,7 +17,7 @@ func CounterClock(c timetools.TimeProvider) rcOptSetter {
} }
} }
// Calculates in memory failure rate of an endpoint using rolling window of a predefined size // RollingCounter Calculates in memory failure rate of an endpoint using rolling window of a predefined size
type RollingCounter struct { type RollingCounter struct {
clock timetools.TimeProvider clock timetools.TimeProvider
resolution time.Duration resolution time.Duration
@ -57,11 +58,13 @@ func NewCounter(buckets int, resolution time.Duration, options ...rcOptSetter) (
return rc, nil return rc, nil
} }
// Append append a counter
func (c *RollingCounter) Append(o *RollingCounter) error { func (c *RollingCounter) Append(o *RollingCounter) error {
c.Inc(int(o.Count())) c.Inc(int(o.Count()))
return nil return nil
} }
// Clone clone a counter
func (c *RollingCounter) Clone() *RollingCounter { func (c *RollingCounter) Clone() *RollingCounter {
c.cleanup() c.cleanup()
other := &RollingCounter{ other := &RollingCounter{
@ -75,6 +78,7 @@ func (c *RollingCounter) Clone() *RollingCounter {
return other return other
} }
// Reset reset a counter
func (c *RollingCounter) Reset() { func (c *RollingCounter) Reset() {
c.lastBucket = -1 c.lastBucket = -1
c.countedBuckets = 0 c.countedBuckets = 0
@ -84,27 +88,33 @@ func (c *RollingCounter) Reset() {
} }
} }
// CountedBuckets gets counted buckets
func (c *RollingCounter) CountedBuckets() int { func (c *RollingCounter) CountedBuckets() int {
return c.countedBuckets return c.countedBuckets
} }
// Count counts
func (c *RollingCounter) Count() int64 { func (c *RollingCounter) Count() int64 {
c.cleanup() c.cleanup()
return c.sum() return c.sum()
} }
// Resolution gets resolution
func (c *RollingCounter) Resolution() time.Duration { func (c *RollingCounter) Resolution() time.Duration {
return c.resolution return c.resolution
} }
// Buckets gets buckets
func (c *RollingCounter) Buckets() int { func (c *RollingCounter) Buckets() int {
return len(c.values) return len(c.values)
} }
// WindowSize gets windows size
func (c *RollingCounter) WindowSize() time.Duration { func (c *RollingCounter) WindowSize() time.Duration {
return time.Duration(len(c.values)) * c.resolution return time.Duration(len(c.values)) * c.resolution
} }
// Inc increment counter
func (c *RollingCounter) Inc(v int) { func (c *RollingCounter) Inc(v int) {
c.cleanup() c.cleanup()
c.incBucketValue(v) c.incBucketValue(v)

View file

@ -20,6 +20,7 @@ type HDRHistogram struct {
h *hdrhistogram.Histogram h *hdrhistogram.Histogram
} }
// NewHDRHistogram creates a new HDRHistogram
func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) { func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error) {
defer func() { defer func() {
if msg := recover(); msg != nil { if msg := recover(); msg != nil {
@ -34,37 +35,42 @@ func NewHDRHistogram(low, high int64, sigfigs int) (h *HDRHistogram, err error)
}, nil }, nil
} }
func (r *HDRHistogram) Export() *HDRHistogram { // Export export a HDRHistogram
var hist *hdrhistogram.Histogram = nil func (h *HDRHistogram) Export() *HDRHistogram {
if r.h != nil { var hist *hdrhistogram.Histogram
snapshot := r.h.Export() if h.h != nil {
snapshot := h.h.Export()
hist = hdrhistogram.Import(snapshot) hist = hdrhistogram.Import(snapshot)
} }
return &HDRHistogram{low: r.low, high: r.high, sigfigs: r.sigfigs, h: hist} return &HDRHistogram{low: h.low, high: h.high, sigfigs: h.sigfigs, h: hist}
} }
// Returns latency at quantile with microsecond precision // LatencyAtQuantile sets latency at quantile with microsecond precision
func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration { func (h *HDRHistogram) LatencyAtQuantile(q float64) time.Duration {
return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond return time.Duration(h.ValueAtQuantile(q)) * time.Microsecond
} }
// Records latencies with microsecond precision // RecordLatencies Records latencies with microsecond precision
func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error { func (h *HDRHistogram) RecordLatencies(d time.Duration, n int64) error {
return h.RecordValues(int64(d/time.Microsecond), n) return h.RecordValues(int64(d/time.Microsecond), n)
} }
// Reset reset a HDRHistogram
func (h *HDRHistogram) Reset() { func (h *HDRHistogram) Reset() {
h.h.Reset() h.h.Reset()
} }
// ValueAtQuantile sets value at quantile
func (h *HDRHistogram) ValueAtQuantile(q float64) int64 { func (h *HDRHistogram) ValueAtQuantile(q float64) int64 {
return h.h.ValueAtQuantile(q) return h.h.ValueAtQuantile(q)
} }
// RecordValues sets record values
func (h *HDRHistogram) RecordValues(v, n int64) error { func (h *HDRHistogram) RecordValues(v, n int64) error {
return h.h.RecordValues(v, n) return h.h.RecordValues(v, n)
} }
// Merge merge a HDRHistogram
func (h *HDRHistogram) Merge(other *HDRHistogram) error { func (h *HDRHistogram) Merge(other *HDRHistogram) error {
if other == nil { if other == nil {
return fmt.Errorf("other is nil") return fmt.Errorf("other is nil")
@ -75,6 +81,7 @@ func (h *HDRHistogram) Merge(other *HDRHistogram) error {
type rhOptSetter func(r *RollingHDRHistogram) error type rhOptSetter func(r *RollingHDRHistogram) error
// RollingClock sets a clock
func RollingClock(clock timetools.TimeProvider) rhOptSetter { func RollingClock(clock timetools.TimeProvider) rhOptSetter {
return func(r *RollingHDRHistogram) error { return func(r *RollingHDRHistogram) error {
r.clock = clock r.clock = clock
@ -82,7 +89,7 @@ func RollingClock(clock timetools.TimeProvider) rhOptSetter {
} }
} }
// RollingHistogram holds multiple histograms and rotates every period. // RollingHDRHistogram holds multiple histograms and rotates every period.
// It provides resulting histogram as a result of a call of 'Merged' function. // It provides resulting histogram as a result of a call of 'Merged' function.
type RollingHDRHistogram struct { type RollingHDRHistogram struct {
idx int idx int
@ -96,6 +103,7 @@ type RollingHDRHistogram struct {
clock timetools.TimeProvider clock timetools.TimeProvider
} }
// NewRollingHDRHistogram created a new RollingHDRHistogram
func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOptSetter) (*RollingHDRHistogram, error) { func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration, bucketCount int, options ...rhOptSetter) (*RollingHDRHistogram, error) {
rh := &RollingHDRHistogram{ rh := &RollingHDRHistogram{
bucketCount: bucketCount, bucketCount: bucketCount,
@ -127,6 +135,7 @@ func NewRollingHDRHistogram(low, high int64, sigfigs int, period time.Duration,
return rh, nil return rh, nil
} }
// Export export a RollingHDRHistogram
func (r *RollingHDRHistogram) Export() *RollingHDRHistogram { func (r *RollingHDRHistogram) Export() *RollingHDRHistogram {
export := &RollingHDRHistogram{} export := &RollingHDRHistogram{}
export.idx = r.idx export.idx = r.idx
@ -147,6 +156,7 @@ func (r *RollingHDRHistogram) Export() *RollingHDRHistogram {
return export return export
} }
// Append append a RollingHDRHistogram
func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error { func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error {
if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs { if r.bucketCount != o.bucketCount || r.period != o.period || r.low != o.low || r.high != o.high || r.sigfigs != o.sigfigs {
return fmt.Errorf("can't merge") return fmt.Errorf("can't merge")
@ -160,6 +170,7 @@ func (r *RollingHDRHistogram) Append(o *RollingHDRHistogram) error {
return nil return nil
} }
// Reset reset a RollingHDRHistogram
func (r *RollingHDRHistogram) Reset() { func (r *RollingHDRHistogram) Reset() {
r.idx = 0 r.idx = 0
r.lastRoll = r.clock.UtcNow() r.lastRoll = r.clock.UtcNow()
@ -173,6 +184,7 @@ func (r *RollingHDRHistogram) rotate() {
r.buckets[r.idx].Reset() r.buckets[r.idx].Reset()
} }
// Merged gets merged histogram
func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) { func (r *RollingHDRHistogram) Merged() (*HDRHistogram, error) {
m, err := NewHDRHistogram(r.low, r.high, r.sigfigs) m, err := NewHDRHistogram(r.low, r.high, r.sigfigs)
if err != nil { if err != nil {
@ -194,10 +206,12 @@ func (r *RollingHDRHistogram) getHist() *HDRHistogram {
return r.buckets[r.idx] return r.buckets[r.idx]
} }
// RecordLatencies sets records latencies
func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error { func (r *RollingHDRHistogram) RecordLatencies(v time.Duration, n int64) error {
return r.getHist().RecordLatencies(v, n) return r.getHist().RecordLatencies(v, n)
} }
// RecordValues set record values
func (r *RollingHDRHistogram) RecordValues(v, n int64) error { func (r *RollingHDRHistogram) RecordValues(v, n int64) error {
return r.getHist().RecordValues(v, n) return r.getHist().RecordValues(v, n)
} }

View file

@ -8,6 +8,7 @@ import (
type ratioOptSetter func(r *RatioCounter) error type ratioOptSetter func(r *RatioCounter) error
// RatioClock sets a clock
func RatioClock(clock timetools.TimeProvider) ratioOptSetter { func RatioClock(clock timetools.TimeProvider) ratioOptSetter {
return func(r *RatioCounter) error { return func(r *RatioCounter) error {
r.clock = clock r.clock = clock
@ -22,6 +23,7 @@ type RatioCounter struct {
b *RollingCounter b *RollingCounter
} }
// NewRatioCounter creates a new RatioCounter
func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptSetter) (*RatioCounter, error) { func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptSetter) (*RatioCounter, error) {
rc := &RatioCounter{} rc := &RatioCounter{}
@ -50,39 +52,48 @@ func NewRatioCounter(buckets int, resolution time.Duration, options ...ratioOptS
return rc, nil return rc, nil
} }
// Reset reset the counter
func (r *RatioCounter) Reset() { func (r *RatioCounter) Reset() {
r.a.Reset() r.a.Reset()
r.b.Reset() r.b.Reset()
} }
// IsReady returns true if the counter is ready
func (r *RatioCounter) IsReady() bool { func (r *RatioCounter) IsReady() bool {
return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values) return r.a.countedBuckets+r.b.countedBuckets >= len(r.a.values)
} }
// CountA gets count A
func (r *RatioCounter) CountA() int64 { func (r *RatioCounter) CountA() int64 {
return r.a.Count() return r.a.Count()
} }
// CountB gets count B
func (r *RatioCounter) CountB() int64 { func (r *RatioCounter) CountB() int64 {
return r.b.Count() return r.b.Count()
} }
// Resolution gets resolution
func (r *RatioCounter) Resolution() time.Duration { func (r *RatioCounter) Resolution() time.Duration {
return r.a.Resolution() return r.a.Resolution()
} }
// Buckets gets buckets
func (r *RatioCounter) Buckets() int { func (r *RatioCounter) Buckets() int {
return r.a.Buckets() return r.a.Buckets()
} }
// WindowSize gets windows size
func (r *RatioCounter) WindowSize() time.Duration { func (r *RatioCounter) WindowSize() time.Duration {
return r.a.WindowSize() return r.a.WindowSize()
} }
// ProcessedCount gets processed count
func (r *RatioCounter) ProcessedCount() int64 { func (r *RatioCounter) ProcessedCount() int64 {
return r.CountA() + r.CountB() return r.CountA() + r.CountB()
} }
// Ratio gets ratio
func (r *RatioCounter) Ratio() float64 { func (r *RatioCounter) Ratio() float64 {
a := r.a.Count() a := r.a.Count()
b := r.b.Count() b := r.b.Count()
@ -93,28 +104,34 @@ func (r *RatioCounter) Ratio() float64 {
return float64(a) / float64(a+b) return float64(a) / float64(a+b)
} }
// IncA increment counter A
func (r *RatioCounter) IncA(v int) { func (r *RatioCounter) IncA(v int) {
r.a.Inc(v) r.a.Inc(v)
} }
// IncB increment counter B
func (r *RatioCounter) IncB(v int) { func (r *RatioCounter) IncB(v int) {
r.b.Inc(v) r.b.Inc(v)
} }
// TestMeter a test meter
type TestMeter struct { type TestMeter struct {
Rate float64 Rate float64
NotReady bool NotReady bool
WindowSize time.Duration WindowSize time.Duration
} }
// GetWindowSize gets windows size
func (tm *TestMeter) GetWindowSize() time.Duration { func (tm *TestMeter) GetWindowSize() time.Duration {
return tm.WindowSize return tm.WindowSize
} }
// IsReady returns true if the meter is ready
func (tm *TestMeter) IsReady() bool { func (tm *TestMeter) IsReady() bool {
return !tm.NotReady return !tm.NotReady
} }
// GetRate gets rate
func (tm *TestMeter) GetRate() float64 { func (tm *TestMeter) GetRate() float64 {
return tm.Rate return tm.Rate
} }

View file

@ -29,10 +29,16 @@ type RTMetrics struct {
type rrOptSetter func(r *RTMetrics) error type rrOptSetter func(r *RTMetrics) error
// NewRTMetricsFn builder function type
type NewRTMetricsFn func() (*RTMetrics, error) type NewRTMetricsFn func() (*RTMetrics, error)
// NewCounterFn builder function type
type NewCounterFn func() (*RollingCounter, error) type NewCounterFn func() (*RollingCounter, error)
// NewRollingHistogramFn builder function type
type NewRollingHistogramFn func() (*RollingHDRHistogram, error) type NewRollingHistogramFn func() (*RollingHDRHistogram, error)
// RTCounter set a builder function for Counter
func RTCounter(new NewCounterFn) rrOptSetter { func RTCounter(new NewCounterFn) rrOptSetter {
return func(r *RTMetrics) error { return func(r *RTMetrics) error {
r.newCounter = new r.newCounter = new
@ -40,13 +46,15 @@ func RTCounter(new NewCounterFn) rrOptSetter {
} }
} }
func RTHistogram(new NewRollingHistogramFn) rrOptSetter { // RTHistogram set a builder function for RollingHistogram
func RTHistogram(fn NewRollingHistogramFn) rrOptSetter {
return func(r *RTMetrics) error { return func(r *RTMetrics) error {
r.newHist = new r.newHist = fn
return nil return nil
} }
} }
// RTClock sets a clock
func RTClock(clock timetools.TimeProvider) rrOptSetter { func RTClock(clock timetools.TimeProvider) rrOptSetter {
return func(r *RTMetrics) error { return func(r *RTMetrics) error {
r.clock = clock r.clock = clock
@ -103,7 +111,7 @@ func NewRTMetrics(settings ...rrOptSetter) (*RTMetrics, error) {
return m, nil return m, nil
} }
// Returns a new RTMetrics which is a copy of the current one // Export Returns a new RTMetrics which is a copy of the current one
func (m *RTMetrics) Export() *RTMetrics { func (m *RTMetrics) Export() *RTMetrics {
m.statusCodesLock.RLock() m.statusCodesLock.RLock()
defer m.statusCodesLock.RUnlock() defer m.statusCodesLock.RUnlock()
@ -130,11 +138,12 @@ func (m *RTMetrics) Export() *RTMetrics {
return export return export
} }
// CounterWindowSize gets total windows size
func (m *RTMetrics) CounterWindowSize() time.Duration { func (m *RTMetrics) CounterWindowSize() time.Duration {
return m.total.WindowSize() return m.total.WindowSize()
} }
// GetNetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection // NetworkErrorRatio calculates the amont of network errors such as time outs and dropped connection
// that occurred in the given time window compared to the total requests count. // that occurred in the given time window compared to the total requests count.
func (m *RTMetrics) NetworkErrorRatio() float64 { func (m *RTMetrics) NetworkErrorRatio() float64 {
if m.total.Count() == 0 { if m.total.Count() == 0 {
@ -143,7 +152,7 @@ func (m *RTMetrics) NetworkErrorRatio() float64 {
return float64(m.netErrors.Count()) / float64(m.total.Count()) return float64(m.netErrors.Count()) / float64(m.total.Count())
} }
// GetResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB) // ResponseCodeRatio calculates ratio of count(startA to endA) / count(startB to endB)
func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 { func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 {
a := int64(0) a := int64(0)
b := int64(0) b := int64(0)
@ -163,6 +172,7 @@ func (m *RTMetrics) ResponseCodeRatio(startA, endA, startB, endB int) float64 {
return 0 return 0
} }
// Append append a metric
func (m *RTMetrics) Append(other *RTMetrics) error { func (m *RTMetrics) Append(other *RTMetrics) error {
if m == other { if m == other {
return errors.New("RTMetrics cannot append to self") return errors.New("RTMetrics cannot append to self")
@ -196,6 +206,7 @@ func (m *RTMetrics) Append(other *RTMetrics) error {
return m.histogram.Append(copied.histogram) return m.histogram.Append(copied.histogram)
} }
// Record records a metric
func (m *RTMetrics) Record(code int, duration time.Duration) { func (m *RTMetrics) Record(code int, duration time.Duration) {
m.total.Inc(1) m.total.Inc(1)
if code == http.StatusGatewayTimeout || code == http.StatusBadGateway { if code == http.StatusGatewayTimeout || code == http.StatusBadGateway {
@ -205,17 +216,17 @@ func (m *RTMetrics) Record(code int, duration time.Duration) {
m.recordLatency(duration) m.recordLatency(duration)
} }
// GetTotalCount returns total count of processed requests collected. // TotalCount returns total count of processed requests collected.
func (m *RTMetrics) TotalCount() int64 { func (m *RTMetrics) TotalCount() int64 {
return m.total.Count() return m.total.Count()
} }
// GetNetworkErrorCount returns total count of processed requests observed // NetworkErrorCount returns total count of processed requests observed
func (m *RTMetrics) NetworkErrorCount() int64 { func (m *RTMetrics) NetworkErrorCount() int64 {
return m.netErrors.Count() return m.netErrors.Count()
} }
// GetStatusCodesCounts returns map with counts of the response codes // StatusCodesCounts returns map with counts of the response codes
func (m *RTMetrics) StatusCodesCounts() map[int]int64 { func (m *RTMetrics) StatusCodesCounts() map[int]int64 {
sc := make(map[int]int64) sc := make(map[int]int64)
m.statusCodesLock.RLock() m.statusCodesLock.RLock()
@ -228,13 +239,14 @@ func (m *RTMetrics) StatusCodesCounts() map[int]int64 {
return sc return sc
} }
// GetLatencyHistogram computes and returns resulting histogram with latencies observed. // LatencyHistogram computes and returns resulting histogram with latencies observed.
func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) { func (m *RTMetrics) LatencyHistogram() (*HDRHistogram, error) {
m.histogramLock.Lock() m.histogramLock.Lock()
defer m.histogramLock.Unlock() defer m.histogramLock.Unlock()
return m.histogram.Merged() return m.histogram.Merged()
} }
// Reset reset metrics
func (m *RTMetrics) Reset() { func (m *RTMetrics) Reset() {
m.statusCodesLock.Lock() m.statusCodesLock.Lock()
defer m.statusCodesLock.Unlock() defer m.statusCodesLock.Unlock()
@ -284,7 +296,7 @@ const (
counterResolution = time.Second counterResolution = time.Second
histMin = 1 histMin = 1
histMax = 3600000000 // 1 hour in microseconds histMax = 3600000000 // 1 hour in microseconds
histSignificantFigures = 2 // signigicant figures (1% precision) histSignificantFigures = 2 // significant figures (1% precision)
histBuckets = 6 // number of sub-histograms in a rolling histogram histBuckets = 6 // number of sub-histograms in a rolling histogram
histPeriod = 10 * time.Second // roll time histPeriod = 10 * time.Second // roll time
) )

View file

@ -7,6 +7,7 @@ import (
"github.com/mailgun/timetools" "github.com/mailgun/timetools"
) )
// UndefinedDelay default delay
const UndefinedDelay = -1 const UndefinedDelay = -1
// rate defines token bucket parameters. // rate defines token bucket parameters.
@ -20,7 +21,7 @@ func (r *rate) String() string {
return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst) return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst)
} }
// Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket) // tokenBucket Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket)
type tokenBucket struct { type tokenBucket struct {
// The time period controlled by the bucket in nanoseconds. // The time period controlled by the bucket in nanoseconds.
period time.Duration period time.Duration
@ -63,7 +64,7 @@ func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) {
tb.updateAvailableTokens() tb.updateAvailableTokens()
tb.lastConsumed = 0 tb.lastConsumed = 0
if tokens > tb.burst { if tokens > tb.burst {
return UndefinedDelay, fmt.Errorf("Requested tokens larger than max tokens") return UndefinedDelay, fmt.Errorf("requested tokens larger than max tokens")
} }
if tb.availableTokens < tokens { if tb.availableTokens < tokens {
return tb.timeTillAvailable(tokens), nil return tb.timeTillAvailable(tokens), nil
@ -83,11 +84,11 @@ func (tb *tokenBucket) rollback() {
tb.lastConsumed = 0 tb.lastConsumed = 0
} }
// Update modifies `average` and `burst` fields of the token bucket according // update modifies `average` and `burst` fields of the token bucket according
// to the provided `Rate` // to the provided `Rate`
func (tb *tokenBucket) update(rate *rate) error { func (tb *tokenBucket) update(rate *rate) error {
if rate.period != tb.period { if rate.period != tb.period {
return fmt.Errorf("Period mismatch: %v != %v", tb.period, rate.period) return fmt.Errorf("period mismatch: %v != %v", tb.period, rate.period)
} }
tb.timePerToken = time.Duration(int64(tb.period) / rate.average) tb.timePerToken = time.Duration(int64(tb.period) / rate.average)
tb.burst = rate.burst tb.burst = rate.burst

View file

@ -2,11 +2,11 @@ package ratelimit
import ( import (
"fmt" "fmt"
"sort"
"strings" "strings"
"time" "time"
"github.com/mailgun/timetools" "github.com/mailgun/timetools"
"sort"
) )
// TokenBucketSet represents a set of TokenBucket covering different time periods. // TokenBucketSet represents a set of TokenBucket covering different time periods.
@ -16,7 +16,7 @@ type TokenBucketSet struct {
clock timetools.TimeProvider clock timetools.TimeProvider
} }
// newTokenBucketSet creates a `TokenBucketSet` from the specified `rates`. // NewTokenBucketSet creates a `TokenBucketSet` from the specified `rates`.
func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet { func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet {
tbs := new(TokenBucketSet) tbs := new(TokenBucketSet)
tbs.clock = clock tbs.clock = clock
@ -54,9 +54,10 @@ func (tbs *TokenBucketSet) Update(rates *RateSet) {
} }
} }
// Consume consume tokens
func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) { func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) {
var maxDelay time.Duration = UndefinedDelay var maxDelay time.Duration = UndefinedDelay
var firstErr error = nil var firstErr error
for _, tokenBucket := range tbs.buckets { for _, tokenBucket := range tbs.buckets {
// We keep calling `Consume` even after a error is returned for one of // We keep calling `Consume` even after a error is returned for one of
// buckets because that allows us to simplify the rollback procedure, // buckets because that allows us to simplify the rollback procedure,
@ -80,6 +81,7 @@ func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) {
return maxDelay, firstErr return maxDelay, firstErr
} }
// GetMaxPeriod returns the max period
func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration { func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration {
return tbs.maxPeriod return tbs.maxPeriod
} }

View file

@ -1,4 +1,4 @@
// Tokenbucket based request rate limiter // Package ratelimit Tokenbucket based request rate limiter
package ratelimit package ratelimit
import ( import (
@ -13,6 +13,7 @@ import (
"github.com/vulcand/oxy/utils" "github.com/vulcand/oxy/utils"
) )
// DefaultCapacity default capacity
const DefaultCapacity = 65536 const DefaultCapacity = 65536
// RateSet maintains a set of rates. It can contain only one rate per period at a time. // RateSet maintains a set of rates. It can contain only one rate per period at a time.
@ -31,15 +32,15 @@ func NewRateSet() *RateSet {
// set then the new rate overrides the old one. // set then the new rate overrides the old one.
func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error { func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error {
if period <= 0 { if period <= 0 {
return fmt.Errorf("Invalid period: %v", period) return fmt.Errorf("invalid period: %v", period)
} }
if average <= 0 { if average <= 0 {
return fmt.Errorf("Invalid average: %v", average) return fmt.Errorf("invalid average: %v", average)
} }
if burst <= 0 { if burst <= 0 {
return fmt.Errorf("Invalid burst: %v", burst) return fmt.Errorf("invalid burst: %v", burst)
} }
rs.m[period] = &rate{period, average, burst} rs.m[period] = &rate{period: period, average: average, burst: burst}
return nil return nil
} }
@ -47,12 +48,15 @@ func (rs *RateSet) String() string {
return fmt.Sprint(rs.m) return fmt.Sprint(rs.m)
} }
// RateExtractor rate extractor
type RateExtractor interface { type RateExtractor interface {
Extract(r *http.Request) (*RateSet, error) Extract(r *http.Request) (*RateSet, error)
} }
// RateExtractorFunc rate extractor function type
type RateExtractorFunc func(r *http.Request) (*RateSet, error) type RateExtractorFunc func(r *http.Request) (*RateSet, error)
// Extract extract from request
func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) { func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) {
return e(r) return e(r)
} }
@ -68,20 +72,24 @@ type TokenLimiter struct {
errHandler utils.ErrorHandler errHandler utils.ErrorHandler
capacity int capacity int
next http.Handler next http.Handler
log *log.Logger
} }
// New constructs a `TokenLimiter` middleware instance. // New constructs a `TokenLimiter` middleware instance.
func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) { func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) {
if defaultRates == nil || len(defaultRates.m) == 0 { if defaultRates == nil || len(defaultRates.m) == 0 {
return nil, fmt.Errorf("Provide default rates") return nil, fmt.Errorf("provide default rates")
} }
if extract == nil { if extract == nil {
return nil, fmt.Errorf("Provide extract function") return nil, fmt.Errorf("provide extract function")
} }
tl := &TokenLimiter{ tl := &TokenLimiter{
next: next, next: next,
defaultRates: defaultRates, defaultRates: defaultRates,
extract: extract, extract: extract,
log: log.StandardLogger(),
} }
for _, o := range opts { for _, o := range opts {
@ -98,6 +106,17 @@ func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet
return tl, nil return tl, nil
} }
// Logger defines the logger the token limiter will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) TokenLimiterOption {
return func(tl *TokenLimiter) error {
tl.log = l
return nil
}
}
// Wrap sets the next handler to be called by token limiter handler.
func (tl *TokenLimiter) Wrap(next http.Handler) { func (tl *TokenLimiter) Wrap(next http.Handler) {
tl.next = next tl.next = next
} }
@ -110,7 +129,7 @@ func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
if err := tl.consumeRates(req, source, amount); err != nil { if err := tl.consumeRates(req, source, amount); err != nil {
log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err) tl.log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err)
tl.errHandler.ServeHTTP(w, req, err) tl.errHandler.ServeHTTP(w, req, err)
return return
} }
@ -155,7 +174,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
rates, err := tl.extractRates.Extract(req) rates, err := tl.extractRates.Extract(req)
if err != nil { if err != nil {
log.Errorf("Failed to retrieve rates: %v", err) tl.log.Errorf("Failed to retrieve rates: %v", err)
return tl.defaultRates return tl.defaultRates
} }
@ -167,6 +186,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
return rates return rates
} }
// MaxRateError max rate error
type MaxRateError struct { type MaxRateError struct {
delay time.Duration delay time.Duration
} }
@ -175,19 +195,21 @@ func (m *MaxRateError) Error() string {
return fmt.Sprintf("max rate reached: retry-in %v", m.delay) return fmt.Sprintf("max rate reached: retry-in %v", m.delay)
} }
type RateErrHandler struct { // RateErrHandler error handler
} type RateErrHandler struct{}
func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if rerr, ok := err.(*MaxRateError); ok { if rerr, ok := err.(*MaxRateError); ok {
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", rerr.delay.Seconds()))
w.Header().Set("X-Retry-In", rerr.delay.String()) w.Header().Set("X-Retry-In", rerr.delay.String())
w.WriteHeader(429) w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(err.Error())) w.Write([]byte(err.Error()))
return return
} }
utils.DefaultHandler.ServeHTTP(w, req, err) utils.DefaultHandler.ServeHTTP(w, req, err)
} }
// TokenLimiterOption token limiter option type
type TokenLimiterOption func(l *TokenLimiter) error type TokenLimiterOption func(l *TokenLimiter) error
// ErrorHandler sets error handler of the server // ErrorHandler sets error handler of the server
@ -198,6 +220,7 @@ func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption {
} }
} }
// ExtractRates sets the rate extractor
func ExtractRates(e RateExtractor) TokenLimiterOption { func ExtractRates(e RateExtractor) TokenLimiterOption {
return func(cl *TokenLimiter) error { return func(cl *TokenLimiter) error {
cl.extractRates = e cl.extractRates = e
@ -205,6 +228,7 @@ func ExtractRates(e RateExtractor) TokenLimiterOption {
} }
} }
// Clock sets the clock
func Clock(clock timetools.TimeProvider) TokenLimiterOption { func Clock(clock timetools.TimeProvider) TokenLimiterOption {
return func(cl *TokenLimiter) error { return func(cl *TokenLimiter) error {
cl.clock = clock cl.clock = clock
@ -212,6 +236,7 @@ func Clock(clock timetools.TimeProvider) TokenLimiterOption {
} }
} }
// Capacity sets the capacity
func Capacity(cap int) TokenLimiterOption { func Capacity(cap int) TokenLimiterOption {
return func(cl *TokenLimiter) error { return func(cl *TokenLimiter) error {
if cap <= 0 { if cap <= 0 {

View file

@ -2,4 +2,5 @@ package roundrobin
import "net/http" import "net/http"
// RequestRewriteListener function to rewrite request
type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request) type RequestRewriteListener func(oldReq *http.Request, newReq *http.Request)

View file

@ -16,13 +16,14 @@ import (
// RebalancerOption - functional option setter for rebalancer // RebalancerOption - functional option setter for rebalancer
type RebalancerOption func(*Rebalancer) error type RebalancerOption func(*Rebalancer) error
// Meter measures server peformance and returns it's relative value via rating // Meter measures server performance and returns it's relative value via rating
type Meter interface { type Meter interface {
Rating() float64 Rating() float64
Record(int, time.Duration) Record(int, time.Duration)
IsReady() bool IsReady() bool
} }
// NewMeterFn type of functions to create new Meter
type NewMeterFn func() (Meter, error) type NewMeterFn func() (Meter, error)
// Rebalancer increases weights on servers that perform better than others. It also rolls back to original weights // Rebalancer increases weights on servers that perform better than others. It also rolls back to original weights
@ -52,8 +53,11 @@ type Rebalancer struct {
stickySession *StickySession stickySession *StickySession
requestRewriteListener RequestRewriteListener requestRewriteListener RequestRewriteListener
log *log.Logger
} }
// RebalancerClock sets a clock
func RebalancerClock(clock timetools.TimeProvider) RebalancerOption { func RebalancerClock(clock timetools.TimeProvider) RebalancerOption {
return func(r *Rebalancer) error { return func(r *Rebalancer) error {
r.clock = clock r.clock = clock
@ -61,6 +65,7 @@ func RebalancerClock(clock timetools.TimeProvider) RebalancerOption {
} }
} }
// RebalancerBackoff sets a beck off duration
func RebalancerBackoff(d time.Duration) RebalancerOption { func RebalancerBackoff(d time.Duration) RebalancerOption {
return func(r *Rebalancer) error { return func(r *Rebalancer) error {
r.backoffDuration = d r.backoffDuration = d
@ -68,6 +73,7 @@ func RebalancerBackoff(d time.Duration) RebalancerOption {
} }
} }
// RebalancerMeter sets a Meter builder function
func RebalancerMeter(newMeter NewMeterFn) RebalancerOption { func RebalancerMeter(newMeter NewMeterFn) RebalancerOption {
return func(r *Rebalancer) error { return func(r *Rebalancer) error {
r.newMeter = newMeter r.newMeter = newMeter
@ -83,6 +89,7 @@ func RebalancerErrorHandler(h utils.ErrorHandler) RebalancerOption {
} }
} }
// RebalancerStickySession sets a sticky session
func RebalancerStickySession(stickySession *StickySession) RebalancerOption { func RebalancerStickySession(stickySession *StickySession) RebalancerOption {
return func(r *Rebalancer) error { return func(r *Rebalancer) error {
r.stickySession = stickySession r.stickySession = stickySession
@ -90,7 +97,7 @@ func RebalancerStickySession(stickySession *StickySession) RebalancerOption {
} }
} }
// RebalancerErrorHandler is a functional argument that sets error handler of the server // RebalancerRequestRewriteListener is a functional argument that sets error handler of the server
func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption { func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOption {
return func(r *Rebalancer) error { return func(r *Rebalancer) error {
r.requestRewriteListener = rrl r.requestRewriteListener = rrl
@ -98,11 +105,14 @@ func RebalancerRequestRewriteListener(rrl RequestRewriteListener) RebalancerOpti
} }
} }
// NewRebalancer creates a new Rebalancer
func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) { func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalancer, error) {
rb := &Rebalancer{ rb := &Rebalancer{
mtx: &sync.Mutex{}, mtx: &sync.Mutex{},
next: handler, next: handler,
stickySession: nil, stickySession: nil,
log: log.StandardLogger(),
} }
for _, o := range opts { for _, o := range opts {
if err := o(rb); err != nil { if err := o(rb); err != nil {
@ -134,6 +144,17 @@ func NewRebalancer(handler balancerHandler, opts ...RebalancerOption) (*Rebalanc
return rb, nil return rb, nil
} }
// RebalancerLogger defines the logger the rebalancer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func RebalancerLogger(l *log.Logger) RebalancerOption {
return func(rb *Rebalancer) error {
rb.log = l
return nil
}
}
// Servers gets all servers
func (rb *Rebalancer) Servers() []*url.URL { func (rb *Rebalancer) Servers() []*url.URL {
rb.mtx.Lock() rb.mtx.Lock()
defer rb.mtx.Unlock() defer rb.mtx.Unlock()
@ -142,8 +163,8 @@ func (rb *Rebalancer) Servers() []*url.URL {
} }
func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel { if rb.log.Level >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry := rb.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request") logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: completed ServeHttp on request") defer logEntry.Debug("vulcand/oxy/roundrobin/rebalancer: completed ServeHttp on request")
} }
@ -169,25 +190,25 @@ func (rb *Rebalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} }
if !stuck { if !stuck {
url, err := rb.next.NextServer() fwdURL, err := rb.next.NextServer()
if err != nil { if err != nil {
rb.errHandler.ServeHTTP(w, req, err) rb.errHandler.ServeHTTP(w, req, err)
return return
} }
if log.GetLevel() >= log.DebugLevel { if log.GetLevel() >= log.DebugLevel {
//log which backend URL we're sending this request to // log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": url}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL") log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": fwdURL}).Debugf("vulcand/oxy/roundrobin/rebalancer: Forwarding this request to URL")
} }
if rb.stickySession != nil { if rb.stickySession != nil {
rb.stickySession.StickBackend(url, &w) rb.stickySession.StickBackend(fwdURL, &w)
} }
newReq.URL = url newReq.URL = fwdURL
} }
//Emit event to a listener if one exists // Emit event to a listener if one exists
if rb.requestRewriteListener != nil { if rb.requestRewriteListener != nil {
rb.requestRewriteListener(req, &newReq) rb.requestRewriteListener(req, &newReq)
} }
@ -215,6 +236,7 @@ func (rb *Rebalancer) reset() {
rb.ratings = make([]float64, len(rb.servers)) rb.ratings = make([]float64, len(rb.servers))
} }
// Wrap sets the next handler to be called by rebalancer handler.
func (rb *Rebalancer) Wrap(next balancerHandler) error { func (rb *Rebalancer) Wrap(next balancerHandler) error {
if rb.next != nil { if rb.next != nil {
return fmt.Errorf("already bound to %T", rb.next) return fmt.Errorf("already bound to %T", rb.next)
@ -223,6 +245,7 @@ func (rb *Rebalancer) Wrap(next balancerHandler) error {
return nil return nil
} }
// UpsertServer upsert a server
func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error { func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error {
rb.mtx.Lock() rb.mtx.Lock()
defer rb.mtx.Unlock() defer rb.mtx.Unlock()
@ -239,6 +262,7 @@ func (rb *Rebalancer) UpsertServer(u *url.URL, options ...ServerOption) error {
return nil return nil
} }
// RemoveServer remove a server
func (rb *Rebalancer) RemoveServer(u *url.URL) error { func (rb *Rebalancer) RemoveServer(u *url.URL) error {
rb.mtx.Lock() rb.mtx.Lock()
defer rb.mtx.Unlock() defer rb.mtx.Unlock()
@ -289,7 +313,7 @@ func (rb *Rebalancer) findServer(u *url.URL) (*rbServer, int) {
return nil, -1 return nil, -1
} }
// Called on every load balancer ServeHTTP call, returns the suggested weights // adjustWeights Called on every load balancer ServeHTTP call, returns the suggested weights
// on every call, can adjust weights if needed. // on every call, can adjust weights if needed.
func (rb *Rebalancer) adjustWeights() { func (rb *Rebalancer) adjustWeights() {
rb.mtx.Lock() rb.mtx.Lock()
@ -319,7 +343,7 @@ func (rb *Rebalancer) adjustWeights() {
func (rb *Rebalancer) applyWeights() { func (rb *Rebalancer) applyWeights() {
for _, srv := range rb.servers { for _, srv := range rb.servers {
log.Debugf("upsert server %v, weight %v", srv.url, srv.curWeight) rb.log.Debugf("upsert server %v, weight %v", srv.url, srv.curWeight)
rb.next.UpsertServer(srv.url, Weight(srv.curWeight)) rb.next.UpsertServer(srv.url, Weight(srv.curWeight))
} }
} }
@ -331,7 +355,7 @@ func (rb *Rebalancer) setMarkedWeights() bool {
if srv.good { if srv.good {
weight := increase(srv.curWeight) weight := increase(srv.curWeight)
if weight <= FSMMaxWeight { if weight <= FSMMaxWeight {
log.Debugf("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight) rb.log.Debugf("increasing weight of %v from %v to %v", srv.url, srv.curWeight, weight)
srv.curWeight = weight srv.curWeight = weight
changed = true changed = true
} }
@ -378,7 +402,7 @@ func (rb *Rebalancer) markServers() bool {
} }
} }
if len(g) != 0 && len(b) != 0 { if len(g) != 0 && len(b) != 0 {
log.Debugf("bad: %v good: %v, ratings: %v", b, g, rb.ratings) rb.log.Debugf("bad: %v good: %v, ratings: %v", b, g, rb.ratings)
} }
return len(g) != 0 && len(b) != 0 return len(g) != 0 && len(b) != 0
} }
@ -433,9 +457,8 @@ func decrease(target, current int) int {
adjusted := current / FSMGrowFactor adjusted := current / FSMGrowFactor
if adjusted < target { if adjusted < target {
return target return target
} else {
return adjusted
} }
return adjusted
} }
// rebalancer server record that keeps track of the original weight supplied by user // rebalancer server record that keeps track of the original weight supplied by user
@ -448,9 +471,9 @@ type rbServer struct {
} }
const ( const (
// This is the maximum weight that handler will set for the server // FSMMaxWeight is the maximum weight that handler will set for the server
FSMMaxWeight = 4096 FSMMaxWeight = 4096
// Multiplier for the server weight // FSMGrowFactor Multiplier for the server weight
FSMGrowFactor = 4 FSMGrowFactor = 4
) )
@ -460,10 +483,12 @@ type codeMeter struct {
codeE int codeE int
} }
// Rating gets ratio
func (n *codeMeter) Rating() float64 { func (n *codeMeter) Rating() float64 {
return n.r.Ratio() return n.r.Ratio()
} }
// Record records a meter
func (n *codeMeter) Record(code int, d time.Duration) { func (n *codeMeter) Record(code int, d time.Duration) {
if code >= n.codeS && code < n.codeE { if code >= n.codeS && code < n.codeE {
n.r.IncA(1) n.r.IncA(1)
@ -472,6 +497,7 @@ func (n *codeMeter) Record(code int, d time.Duration) {
} }
} }
// IsReady returns true if the counter is ready
func (n *codeMeter) IsReady() bool { func (n *codeMeter) IsReady() bool {
return n.r.IsReady() return n.r.IsReady()
} }

View file

@ -1,4 +1,4 @@
// package roundrobin implements dynamic weighted round robin load balancer http handler // Package roundrobin implements dynamic weighted round robin load balancer http handler
package roundrobin package roundrobin
import ( import (
@ -30,6 +30,7 @@ func ErrorHandler(h utils.ErrorHandler) LBOption {
} }
} }
// EnableStickySession enable sticky session
func EnableStickySession(stickySession *StickySession) LBOption { func EnableStickySession(stickySession *StickySession) LBOption {
return func(s *RoundRobin) error { return func(s *RoundRobin) error {
s.stickySession = stickySession s.stickySession = stickySession
@ -37,7 +38,7 @@ func EnableStickySession(stickySession *StickySession) LBOption {
} }
} }
// ErrorHandler is a functional argument that sets error handler of the server // RoundRobinRequestRewriteListener is a functional argument that sets error handler of the server
func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption {
return func(s *RoundRobin) error { return func(s *RoundRobin) error {
s.requestRewriteListener = rrl s.requestRewriteListener = rrl
@ -45,6 +46,7 @@ func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption {
} }
} }
// RoundRobin implements dynamic weighted round robin load balancer http handler
type RoundRobin struct { type RoundRobin struct {
mutex *sync.Mutex mutex *sync.Mutex
next http.Handler next http.Handler
@ -55,8 +57,11 @@ type RoundRobin struct {
currentWeight int currentWeight int
stickySession *StickySession stickySession *StickySession
requestRewriteListener RequestRewriteListener requestRewriteListener RequestRewriteListener
log *log.Logger
} }
// New created a new RoundRobin
func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) { func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
rr := &RoundRobin{ rr := &RoundRobin{
next: next, next: next,
@ -64,6 +69,8 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
mutex: &sync.Mutex{}, mutex: &sync.Mutex{},
servers: []*server{}, servers: []*server{},
stickySession: nil, stickySession: nil,
log: log.StandardLogger(),
} }
for _, o := range opts { for _, o := range opts {
if err := o(rr); err != nil { if err := o(rr); err != nil {
@ -76,13 +83,24 @@ func New(next http.Handler, opts ...LBOption) (*RoundRobin, error) {
return rr, nil return rr, nil
} }
// RoundRobinLogger defines the logger the round robin load balancer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func RoundRobinLogger(l *log.Logger) LBOption {
return func(r *RoundRobin) error {
r.log = l
return nil
}
}
// Next returns the next handler
func (r *RoundRobin) Next() http.Handler { func (r *RoundRobin) Next() http.Handler {
return r.next return r.next
} }
func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if log.GetLevel() >= log.DebugLevel { if r.log.Level >= log.DebugLevel {
logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry := r.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request") logEntry.Debug("vulcand/oxy/roundrobin/rr: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/roundrobin/rr: completed ServeHttp on request") defer logEntry.Debug("vulcand/oxy/roundrobin/rr: completed ServeHttp on request")
} }
@ -116,12 +134,12 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
newReq.URL = url newReq.URL = url
} }
if log.GetLevel() >= log.DebugLevel { if r.log.Level >= log.DebugLevel {
//log which backend URL we're sending this request to // log which backend URL we're sending this request to
log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL") r.log.WithFields(log.Fields{"Request": utils.DumpHttpRequest(req), "ForwardURL": newReq.URL}).Debugf("vulcand/oxy/roundrobin/rr: Forwarding this request to URL")
} }
//Emit event to a listener if one exists // Emit event to a listener if one exists
if r.requestRewriteListener != nil { if r.requestRewriteListener != nil {
r.requestRewriteListener(req, &newReq) r.requestRewriteListener(req, &newReq)
} }
@ -129,6 +147,7 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.next.ServeHTTP(w, &newReq) r.next.ServeHTTP(w, &newReq)
} }
// NextServer gets the next server
func (r *RoundRobin) NextServer() (*url.URL, error) { func (r *RoundRobin) NextServer() (*url.URL, error) {
srv, err := r.nextServer() srv, err := r.nextServer()
if err != nil { if err != nil {
@ -172,6 +191,7 @@ func (r *RoundRobin) nextServer() (*server, error) {
} }
} }
// RemoveServer remove a server
func (r *RoundRobin) RemoveServer(u *url.URL) error { func (r *RoundRobin) RemoveServer(u *url.URL) error {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@ -185,6 +205,7 @@ func (r *RoundRobin) RemoveServer(u *url.URL) error {
return nil return nil
} }
// Servers gets servers URL
func (r *RoundRobin) Servers() []*url.URL { func (r *RoundRobin) Servers() []*url.URL {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@ -196,6 +217,7 @@ func (r *RoundRobin) Servers() []*url.URL {
return out return out
} }
// ServerWeight gets the server weight
func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) { func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@ -206,7 +228,7 @@ func (r *RoundRobin) ServerWeight(u *url.URL) (int, bool) {
return -1, false return -1, false
} }
// In case if server is already present in the load balancer, returns error // UpsertServer In case if server is already present in the load balancer, returns error
func (r *RoundRobin) UpsertServer(u *url.URL, options ...ServerOption) error { func (r *RoundRobin) UpsertServer(u *url.URL, options ...ServerOption) error {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
@ -306,6 +328,7 @@ type server struct {
var defaultWeight = 1 var defaultWeight = 1
// SetDefaultWeight sets the default server weight
func SetDefaultWeight(weight int) error { func SetDefaultWeight(weight int) error {
if weight < 0 { if weight < 0 {
return fmt.Errorf("default weight should be >= 0") return fmt.Errorf("default weight should be >= 0")

View file

@ -1,4 +1,3 @@
// package stickysession is a mixin for load balancers that implements layer 7 (http cookie) session affinity
package roundrobin package roundrobin
import ( import (
@ -6,12 +5,14 @@ import (
"net/url" "net/url"
) )
// StickySession is a mixin for load balancers that implements layer 7 (http cookie) session affinity
type StickySession struct { type StickySession struct {
cookieName string cookieName string
} }
// NewStickySession creates a new StickySession
func NewStickySession(cookieName string) *StickySession { func NewStickySession(cookieName string) *StickySession {
return &StickySession{cookieName} return &StickySession{cookieName: cookieName}
} }
// GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers. // GetBackend returns the backend URL stored in the sticky cookie, iff the backend is still in the valid list of servers.
@ -32,11 +33,11 @@ func (s *StickySession) GetBackend(req *http.Request, servers []*url.URL) (*url.
if s.isBackendAlive(serverURL, servers) { if s.isBackendAlive(serverURL, servers) {
return serverURL, true, nil return serverURL, true, nil
} else {
return nil, false, nil
} }
return nil, false, nil
} }
// StickBackend creates and sets the cookie
func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) { func (s *StickySession) StickBackend(backend *url.URL, w *http.ResponseWriter) {
cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"} cookie := &http.Cookie{Name: s.cookieName, Value: backend.String(), Path: "/"}
http.SetCookie(*w, cookie) http.SetCookie(*w, cookie)

View file

@ -6,6 +6,7 @@ import (
"strings" "strings"
) )
// BasicAuth basic auth information
type BasicAuth struct { type BasicAuth struct {
Username string Username string
Password string Password string
@ -16,6 +17,7 @@ func (ba *BasicAuth) String() string {
return fmt.Sprintf("Basic %s", encoded) return fmt.Sprintf("Basic %s", encoded)
} }
// ParseAuthHeader creates a new BasicAuth from header values
func ParseAuthHeader(header string) (*BasicAuth, error) { func ParseAuthHeader(header string) (*BasicAuth, error) {
values := strings.Fields(header) values := strings.Fields(header)
if len(values) != 2 { if len(values) != 2 {

View file

@ -9,6 +9,7 @@ import (
"net/url" "net/url"
) )
// SerializableHttpRequest serializable HTTP request
type SerializableHttpRequest struct { type SerializableHttpRequest struct {
Method string Method string
URL *url.URL URL *url.URL
@ -28,6 +29,7 @@ type SerializableHttpRequest struct {
TLS *tls.ConnectionState TLS *tls.ConnectionState
} }
// Clone clone a request
func Clone(r *http.Request) *SerializableHttpRequest { func Clone(r *http.Request) *SerializableHttpRequest {
if r == nil { if r == nil {
return nil return nil
@ -47,14 +49,16 @@ func Clone(r *http.Request) *SerializableHttpRequest {
return rc return rc
} }
// ToJson serializes to JSON
func (s *SerializableHttpRequest) ToJson() string { func (s *SerializableHttpRequest) ToJson() string {
if jsonVal, err := json.Marshal(s); err != nil || jsonVal == nil { jsonVal, err := json.Marshal(s)
return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err.Error()) if err != nil || jsonVal == nil {
} else { return fmt.Sprintf("Error marshalling SerializableHttpRequest to json: %s", err)
return string(jsonVal)
} }
return string(jsonVal)
} }
// DumpHttpRequest dump a HTTP request to JSON
func DumpHttpRequest(req *http.Request) string { func DumpHttpRequest(req *http.Request) string {
return fmt.Sprintf("%v", Clone(req).ToJson()) return Clone(req).ToJson()
} }

View file

@ -8,14 +8,16 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// ErrorHandler error handler
type ErrorHandler interface { type ErrorHandler interface {
ServeHTTP(w http.ResponseWriter, req *http.Request, err error) ServeHTTP(w http.ResponseWriter, req *http.Request, err error)
} }
// DefaultHandler default error handler
var DefaultHandler ErrorHandler = &StdHandler{} var DefaultHandler ErrorHandler = &StdHandler{}
type StdHandler struct { // StdHandler Standard error handler
} type StdHandler struct{}
func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
statusCode := http.StatusInternalServerError statusCode := http.StatusInternalServerError
@ -33,6 +35,7 @@ func (e *StdHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err err
log.Debugf("'%d %s' caused by: %v", statusCode, http.StatusText(statusCode), err) log.Debugf("'%d %s' caused by: %v", statusCode, http.StatusText(statusCode), err)
} }
// ErrorHandlerFunc error handler function type
type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error) type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, error)
// ServeHTTP calls f(w, r). // ServeHTTP calls f(w, r).

View file

@ -12,18 +12,29 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// ProxyWriter calls recorder, used to debug logs
type ProxyWriter struct { type ProxyWriter struct {
W http.ResponseWriter w http.ResponseWriter
code int code int
length int64 length int64
log *log.Logger
} }
func NewProxyWriter(writer http.ResponseWriter) *ProxyWriter { // NewProxyWriter creates a new ProxyWriter
func NewProxyWriter(w http.ResponseWriter) *ProxyWriter {
return NewProxyWriterWithLogger(w, log.StandardLogger())
}
// NewProxyWriterWithLogger creates a new ProxyWriter
func NewProxyWriterWithLogger(w http.ResponseWriter, l *log.Logger) *ProxyWriter {
return &ProxyWriter{ return &ProxyWriter{
W: writer, w: w,
log: l,
} }
} }
// StatusCode gets status code
func (p *ProxyWriter) StatusCode() int { func (p *ProxyWriter) StatusCode() int {
if p.code == 0 { if p.code == 0 {
// per contract standard lib will set this to http.StatusOK if not set // per contract standard lib will set this to http.StatusOK if not set
@ -33,46 +44,54 @@ func (p *ProxyWriter) StatusCode() int {
return p.code return p.code
} }
// GetLength gets content length
func (p *ProxyWriter) GetLength() int64 { func (p *ProxyWriter) GetLength() int64 {
return p.length return p.length
} }
// Header gets response header
func (p *ProxyWriter) Header() http.Header { func (p *ProxyWriter) Header() http.Header {
return p.W.Header() return p.w.Header()
} }
func (p *ProxyWriter) Write(buf []byte) (int, error) { func (p *ProxyWriter) Write(buf []byte) (int, error) {
p.length = p.length + int64(len(buf)) p.length = p.length + int64(len(buf))
return p.W.Write(buf) return p.w.Write(buf)
} }
// WriteHeader writes status code
func (p *ProxyWriter) WriteHeader(code int) { func (p *ProxyWriter) WriteHeader(code int) {
p.code = code p.code = code
p.W.WriteHeader(code) p.w.WriteHeader(code)
} }
// Flush flush the writer
func (p *ProxyWriter) Flush() { func (p *ProxyWriter) Flush() {
if f, ok := p.W.(http.Flusher); ok { if f, ok := p.w.(http.Flusher); ok {
f.Flush() f.Flush()
} }
} }
// CloseNotify returns a channel that receives at most a single value (true)
// when the client connection has gone away.
func (p *ProxyWriter) CloseNotify() <-chan bool { func (p *ProxyWriter) CloseNotify() <-chan bool {
if cn, ok := p.W.(http.CloseNotifier); ok { if cn, ok := p.w.(http.CloseNotifier); ok {
return cn.CloseNotify() return cn.CloseNotify()
} }
log.Debugf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.W)) p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.CloseNotifier. Returning dummy channel.", reflect.TypeOf(p.w))
return make(<-chan bool) return make(<-chan bool)
} }
// Hijack lets the caller take over the connection.
func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (p *ProxyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hi, ok := p.W.(http.Hijacker); ok { if hi, ok := p.w.(http.Hijacker); ok {
return hi.Hijack() return hi.Hijack()
} }
log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.W)) p.log.Debugf("Upstream ResponseWriter of type %v does not implement http.Hijacker. Returning dummy channel.", reflect.TypeOf(p.w))
return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.W)) return nil, nil, fmt.Errorf("the response writer that was wrapped in this proxy, does not implement http.Hijacker. It is of type: %v", reflect.TypeOf(p.w))
} }
// NewBufferWriter creates a new BufferWriter
func NewBufferWriter(w io.WriteCloser) *BufferWriter { func NewBufferWriter(w io.WriteCloser) *BufferWriter {
return &BufferWriter{ return &BufferWriter{
W: w, W: w,
@ -80,16 +99,19 @@ func NewBufferWriter(w io.WriteCloser) *BufferWriter {
} }
} }
// BufferWriter buffer writer
type BufferWriter struct { type BufferWriter struct {
H http.Header H http.Header
Code int Code int
W io.WriteCloser W io.WriteCloser
} }
// Close close the writer
func (b *BufferWriter) Close() error { func (b *BufferWriter) Close() error {
return b.W.Close() return b.W.Close()
} }
// Header gets response header
func (b *BufferWriter) Header() http.Header { func (b *BufferWriter) Header() http.Header {
return b.H return b.H
} }
@ -98,11 +120,13 @@ func (b *BufferWriter) Write(buf []byte) (int, error) {
return b.W.Write(buf) return b.W.Write(buf)
} }
// WriteHeader sets rw.Code. // WriteHeader writes status code
func (b *BufferWriter) WriteHeader(code int) { func (b *BufferWriter) WriteHeader(code int) {
b.Code = code b.Code = code
} }
// CloseNotify returns a channel that receives at most a single value (true)
// when the client connection has gone away.
func (b *BufferWriter) CloseNotify() <-chan bool { func (b *BufferWriter) CloseNotify() <-chan bool {
if cn, ok := b.W.(http.CloseNotifier); ok { if cn, ok := b.W.(http.CloseNotifier); ok {
return cn.CloseNotify() return cn.CloseNotify()
@ -111,6 +135,7 @@ func (b *BufferWriter) CloseNotify() <-chan bool {
return make(<-chan bool) return make(<-chan bool)
} }
// Hijack lets the caller take over the connection.
func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (b *BufferWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hi, ok := b.W.(http.Hijacker); ok { if hi, ok := b.W.(http.Hijacker); ok {
return hi.Hijack() return hi.Hijack()
@ -125,10 +150,10 @@ type nopWriteCloser struct {
func (*nopWriteCloser) Close() error { return nil } func (*nopWriteCloser) Close() error { return nil }
// NopCloser returns a WriteCloser with a no-op Close method wrapping // NopWriteCloser returns a WriteCloser with a no-op Close method wrapping
// the provided Writer w. // the provided Writer w.
func NopWriteCloser(w io.Writer) io.WriteCloser { func NopWriteCloser(w io.Writer) io.WriteCloser {
return &nopWriteCloser{w} return &nopWriteCloser{Writer: w}
} }
// CopyURL provides update safe copy by avoiding shallow copying User field // CopyURL provides update safe copy by avoiding shallow copying User field

View file

@ -6,21 +6,25 @@ import (
"strings" "strings"
) )
// ExtractSource extracts the source from the request, e.g. that may be client ip, or particular header that // SourceExtractor extracts the source from the request, e.g. that may be client ip, or particular header that
// identifies the source. amount stands for amount of connections the source consumes, usually 1 for connection limiters // identifies the source. amount stands for amount of connections the source consumes, usually 1 for connection limiters
// error should be returned when source can not be identified // error should be returned when source can not be identified
type SourceExtractor interface { type SourceExtractor interface {
Extract(req *http.Request) (token string, amount int64, err error) Extract(req *http.Request) (token string, amount int64, err error)
} }
// ExtractorFunc extractor function type
type ExtractorFunc func(req *http.Request) (token string, amount int64, err error) type ExtractorFunc func(req *http.Request) (token string, amount int64, err error)
// Extract extract from request
func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) { func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) {
return f(req) return f(req)
} }
// ExtractSource extract source function type
type ExtractSource func(req *http.Request) type ExtractSource func(req *http.Request)
// NewExtractor creates a new SourceExtractor
func NewExtractor(variable string) (SourceExtractor, error) { func NewExtractor(variable string) (SourceExtractor, error) {
if variable == "client.ip" { if variable == "client.ip" {
return ExtractorFunc(extractClientIP), nil return ExtractorFunc(extractClientIP), nil
@ -31,17 +35,17 @@ func NewExtractor(variable string) (SourceExtractor, error) {
if strings.HasPrefix(variable, "request.header.") { if strings.HasPrefix(variable, "request.header.") {
header := strings.TrimPrefix(variable, "request.header.") header := strings.TrimPrefix(variable, "request.header.")
if len(header) == 0 { if len(header) == 0 {
return nil, fmt.Errorf("Wrong header: %s", header) return nil, fmt.Errorf("wrong header: %s", header)
} }
return makeHeaderExtractor(header), nil return makeHeaderExtractor(header), nil
} }
return nil, fmt.Errorf("Unsupported limiting variable: '%s'", variable) return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable)
} }
func extractClientIP(req *http.Request) (string, int64, error) { func extractClientIP(req *http.Request) (string, int64, error) {
vals := strings.SplitN(req.RemoteAddr, ":", 2) vals := strings.SplitN(req.RemoteAddr, ":", 2)
if len(vals[0]) == 0 { if len(vals[0]) == 0 {
return "", 0, fmt.Errorf("Failed to parse client IP: %v", req.RemoteAddr) return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr)
} }
return vals[0], 1, nil return vals[0], 1, nil
} }