// package connlimit provides control over simultaneous connections coming from the same source package connlimit import ( "fmt" "net/http" "sync" log "github.com/Sirupsen/logrus" "github.com/vulcand/oxy/utils" ) // Limiter tracks concurrent connection per token // and is capable of rejecting connections if they are failed type ConnLimiter struct { mutex *sync.Mutex extract utils.SourceExtractor connections map[string]int64 maxConnections int64 totalConnections int64 next http.Handler errHandler utils.ErrorHandler } func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) { if extract == nil { return nil, fmt.Errorf("Extract function can not be nil") } cl := &ConnLimiter{ mutex: &sync.Mutex{}, extract: extract, maxConnections: maxConnections, connections: make(map[string]int64), next: next, } for _, o := range options { if err := o(cl); err != nil { return nil, err } } if cl.errHandler == nil { cl.errHandler = defaultErrHandler } return cl, nil } func (cl *ConnLimiter) Wrap(h http.Handler) { cl.next = h } func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { token, amount, err := cl.extract.Extract(r) if err != nil { log.Errorf("failed to extract source of the connection: %v", err) cl.errHandler.ServeHTTP(w, r, err) return } if err := cl.acquire(token, amount); err != nil { log.Infof("limiting request source %s: %v", token, err) cl.errHandler.ServeHTTP(w, r, err) return } defer cl.release(token, amount) cl.next.ServeHTTP(w, r) } func (cl *ConnLimiter) acquire(token string, amount int64) error { cl.mutex.Lock() defer cl.mutex.Unlock() connections := cl.connections[token] if connections >= cl.maxConnections { return &MaxConnError{max: cl.maxConnections} } cl.connections[token] += amount cl.totalConnections += amount return nil } func (cl *ConnLimiter) release(token string, amount int64) { cl.mutex.Lock() defer cl.mutex.Unlock() cl.connections[token] -= amount cl.totalConnections -= amount // Otherwise it would grow forever if cl.connections[token] == 0 { delete(cl.connections, token) } } type MaxConnError struct { max int64 } func (m *MaxConnError) Error() string { return fmt.Sprintf("max connections reached: %d", m.max) } type ConnErrHandler struct { } func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) { if log.GetLevel() >= log.DebugLevel { logEntry := log.WithField("Request", utils.DumpHttpRequest(req)) logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request") defer logEntry.Debug("vulcand/oxy/connlimit: competed ServeHttp on request") } if _, ok := err.(*MaxConnError); ok { w.WriteHeader(429) w.Write([]byte(err.Error())) return } utils.DefaultHandler.ServeHTTP(w, req, err) } type ConnLimitOption func(l *ConnLimiter) error // ErrorHandler sets error handler of the server func ErrorHandler(h utils.ErrorHandler) ConnLimitOption { return func(cl *ConnLimiter) error { cl.errHandler = h return nil } } var defaultErrHandler = &ConnErrHandler{}