134 lines
3.2 KiB
Go
134 lines
3.2 KiB
Go
// package connlimit provides control over simultaneous connections coming from the same source
|
|
package connlimit
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
|
|
log "github.com/Sirupsen/logrus"
|
|
"github.com/vulcand/oxy/utils"
|
|
)
|
|
|
|
// Limiter tracks concurrent connection per token
|
|
// and is capable of rejecting connections if they are failed
|
|
type ConnLimiter struct {
|
|
mutex *sync.Mutex
|
|
extract utils.SourceExtractor
|
|
connections map[string]int64
|
|
maxConnections int64
|
|
totalConnections int64
|
|
next http.Handler
|
|
|
|
errHandler utils.ErrorHandler
|
|
}
|
|
|
|
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) {
|
|
if extract == nil {
|
|
return nil, fmt.Errorf("Extract function can not be nil")
|
|
}
|
|
cl := &ConnLimiter{
|
|
mutex: &sync.Mutex{},
|
|
extract: extract,
|
|
maxConnections: maxConnections,
|
|
connections: make(map[string]int64),
|
|
next: next,
|
|
}
|
|
|
|
for _, o := range options {
|
|
if err := o(cl); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if cl.errHandler == nil {
|
|
cl.errHandler = defaultErrHandler
|
|
}
|
|
return cl, nil
|
|
}
|
|
|
|
func (cl *ConnLimiter) Wrap(h http.Handler) {
|
|
cl.next = h
|
|
}
|
|
|
|
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
token, amount, err := cl.extract.Extract(r)
|
|
if err != nil {
|
|
log.Errorf("failed to extract source of the connection: %v", err)
|
|
cl.errHandler.ServeHTTP(w, r, err)
|
|
return
|
|
}
|
|
if err := cl.acquire(token, amount); err != nil {
|
|
log.Infof("limiting request source %s: %v", token, err)
|
|
cl.errHandler.ServeHTTP(w, r, err)
|
|
return
|
|
}
|
|
|
|
defer cl.release(token, amount)
|
|
|
|
cl.next.ServeHTTP(w, r)
|
|
}
|
|
|
|
func (cl *ConnLimiter) acquire(token string, amount int64) error {
|
|
cl.mutex.Lock()
|
|
defer cl.mutex.Unlock()
|
|
|
|
connections := cl.connections[token]
|
|
if connections >= cl.maxConnections {
|
|
return &MaxConnError{max: cl.maxConnections}
|
|
}
|
|
|
|
cl.connections[token] += amount
|
|
cl.totalConnections += amount
|
|
return nil
|
|
}
|
|
|
|
func (cl *ConnLimiter) release(token string, amount int64) {
|
|
cl.mutex.Lock()
|
|
defer cl.mutex.Unlock()
|
|
|
|
cl.connections[token] -= amount
|
|
cl.totalConnections -= amount
|
|
|
|
// Otherwise it would grow forever
|
|
if cl.connections[token] == 0 {
|
|
delete(cl.connections, token)
|
|
}
|
|
}
|
|
|
|
type MaxConnError struct {
|
|
max int64
|
|
}
|
|
|
|
func (m *MaxConnError) Error() string {
|
|
return fmt.Sprintf("max connections reached: %d", m.max)
|
|
}
|
|
|
|
type ConnErrHandler struct {
|
|
}
|
|
|
|
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
|
|
if log.GetLevel() >= log.DebugLevel {
|
|
logEntry := log.WithField("Request", utils.DumpHttpRequest(req))
|
|
logEntry.Debug("vulcand/oxy/connlimit: begin ServeHttp on request")
|
|
defer logEntry.Debug("vulcand/oxy/connlimit: competed ServeHttp on request")
|
|
}
|
|
|
|
if _, ok := err.(*MaxConnError); ok {
|
|
w.WriteHeader(429)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
utils.DefaultHandler.ServeHTTP(w, req, err)
|
|
}
|
|
|
|
type ConnLimitOption func(l *ConnLimiter) error
|
|
|
|
// ErrorHandler sets error handler of the server
|
|
func ErrorHandler(h utils.ErrorHandler) ConnLimitOption {
|
|
return func(cl *ConnLimiter) error {
|
|
cl.errHandler = h
|
|
return nil
|
|
}
|
|
}
|
|
|
|
var defaultErrHandler = &ConnErrHandler{}
|