traefik/vendor/github.com/vulcand/oxy/connlimit/connlimit.go

140 lines
3.1 KiB
Go
Raw Normal View History

2017-02-07 22:33:23 +01:00
// package connlimit provides control over simultaneous connections coming from the same source
package connlimit
import (
"fmt"
"net/http"
"sync"
"github.com/vulcand/oxy/utils"
)
// Limiter tracks concurrent connection per token
// and is capable of rejecting connections if they are failed
type ConnLimiter struct {
mutex *sync.Mutex
extract utils.SourceExtractor
connections map[string]int64
maxConnections int64
totalConnections int64
next http.Handler
errHandler utils.ErrorHandler
log utils.Logger
}
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) {
if extract == nil {
return nil, fmt.Errorf("Extract function can not be nil")
}
cl := &ConnLimiter{
mutex: &sync.Mutex{},
extract: extract,
maxConnections: maxConnections,
connections: make(map[string]int64),
next: next,
}
for _, o := range options {
if err := o(cl); err != nil {
return nil, err
}
}
if cl.log == nil {
cl.log = utils.NullLogger
}
if cl.errHandler == nil {
cl.errHandler = defaultErrHandler
}
return cl, nil
}
func (cl *ConnLimiter) Wrap(h http.Handler) {
cl.next = h
}
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
token, amount, err := cl.extract.Extract(r)
if err != nil {
cl.log.Errorf("failed to extract source of the connection: %v", err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
if err := cl.acquire(token, amount); err != nil {
cl.log.Infof("limiting request source %s: %v", token, err)
cl.errHandler.ServeHTTP(w, r, err)
return
}
defer cl.release(token, amount)
cl.next.ServeHTTP(w, r)
}
func (cl *ConnLimiter) acquire(token string, amount int64) error {
cl.mutex.Lock()
defer cl.mutex.Unlock()
connections := cl.connections[token]
if connections >= cl.maxConnections {
return &MaxConnError{max: cl.maxConnections}
}
cl.connections[token] += amount
cl.totalConnections += int64(amount)
return nil
}
func (cl *ConnLimiter) release(token string, amount int64) {
cl.mutex.Lock()
defer cl.mutex.Unlock()
cl.connections[token] -= amount
cl.totalConnections -= int64(amount)
// Otherwise it would grow forever
if cl.connections[token] == 0 {
delete(cl.connections, token)
}
}
type MaxConnError struct {
max int64
}
func (m *MaxConnError) Error() string {
return fmt.Sprintf("max connections reached: %d", m.max)
}
type ConnErrHandler struct {
}
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
if _, ok := err.(*MaxConnError); ok {
w.WriteHeader(429)
w.Write([]byte(err.Error()))
return
}
utils.DefaultHandler.ServeHTTP(w, req, err)
}
type ConnLimitOption func(l *ConnLimiter) error
// Logger sets the logger that will be used by this middleware.
func Logger(l utils.Logger) ConnLimitOption {
return func(cl *ConnLimiter) error {
cl.log = l
return nil
}
}
// ErrorHandler sets error handler of the server
func ErrorHandler(h utils.ErrorHandler) ConnLimitOption {
return func(cl *ConnLimiter) error {
cl.errHandler = h
return nil
}
}
var defaultErrHandler = &ConnErrHandler{}