237 lines
5.6 KiB
Go
237 lines
5.6 KiB
Go
// Tokenbucket based request rate limiter
|
|
package ratelimit
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/mailgun/timetools"
|
|
"github.com/mailgun/ttlmap"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/vulcand/oxy/utils"
|
|
)
|
|
|
|
const DefaultCapacity = 65536
|
|
|
|
// RateSet maintains a set of rates. It can contain only one rate per period at a time.
|
|
type RateSet struct {
|
|
m map[time.Duration]*rate
|
|
}
|
|
|
|
// NewRateSet crates an empty `RateSet` instance.
|
|
func NewRateSet() *RateSet {
|
|
rs := new(RateSet)
|
|
rs.m = make(map[time.Duration]*rate)
|
|
return rs
|
|
}
|
|
|
|
// Add adds a rate to the set. If there is a rate with the same period in the
|
|
// set then the new rate overrides the old one.
|
|
func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error {
|
|
if period <= 0 {
|
|
return fmt.Errorf("Invalid period: %v", period)
|
|
}
|
|
if average <= 0 {
|
|
return fmt.Errorf("Invalid average: %v", average)
|
|
}
|
|
if burst <= 0 {
|
|
return fmt.Errorf("Invalid burst: %v", burst)
|
|
}
|
|
rs.m[period] = &rate{period, average, burst}
|
|
return nil
|
|
}
|
|
|
|
func (rs *RateSet) String() string {
|
|
return fmt.Sprint(rs.m)
|
|
}
|
|
|
|
type RateExtractor interface {
|
|
Extract(r *http.Request) (*RateSet, error)
|
|
}
|
|
|
|
type RateExtractorFunc func(r *http.Request) (*RateSet, error)
|
|
|
|
func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) {
|
|
return e(r)
|
|
}
|
|
|
|
// TokenLimiter implements rate limiting middleware.
|
|
type TokenLimiter struct {
|
|
defaultRates *RateSet
|
|
extract utils.SourceExtractor
|
|
extractRates RateExtractor
|
|
clock timetools.TimeProvider
|
|
mutex sync.Mutex
|
|
bucketSets *ttlmap.TtlMap
|
|
errHandler utils.ErrorHandler
|
|
capacity int
|
|
next http.Handler
|
|
}
|
|
|
|
// New constructs a `TokenLimiter` middleware instance.
|
|
func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) {
|
|
if defaultRates == nil || len(defaultRates.m) == 0 {
|
|
return nil, fmt.Errorf("Provide default rates")
|
|
}
|
|
if extract == nil {
|
|
return nil, fmt.Errorf("Provide extract function")
|
|
}
|
|
tl := &TokenLimiter{
|
|
next: next,
|
|
defaultRates: defaultRates,
|
|
extract: extract,
|
|
}
|
|
|
|
for _, o := range opts {
|
|
if err := o(tl); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
setDefaults(tl)
|
|
bucketSets, err := ttlmap.NewMapWithProvider(tl.capacity, tl.clock)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tl.bucketSets = bucketSets
|
|
return tl, nil
|
|
}
|
|
|
|
func (tl *TokenLimiter) Wrap(next http.Handler) {
|
|
tl.next = next
|
|
}
|
|
|
|
func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
source, amount, err := tl.extract.Extract(req)
|
|
if err != nil {
|
|
tl.errHandler.ServeHTTP(w, req, err)
|
|
return
|
|
}
|
|
|
|
if err := tl.consumeRates(req, source, amount); err != nil {
|
|
log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err)
|
|
tl.errHandler.ServeHTTP(w, req, err)
|
|
return
|
|
}
|
|
|
|
tl.next.ServeHTTP(w, req)
|
|
}
|
|
|
|
func (tl *TokenLimiter) consumeRates(req *http.Request, source string, amount int64) error {
|
|
tl.mutex.Lock()
|
|
defer tl.mutex.Unlock()
|
|
|
|
effectiveRates := tl.resolveRates(req)
|
|
bucketSetI, exists := tl.bucketSets.Get(source)
|
|
var bucketSet *TokenBucketSet
|
|
|
|
if exists {
|
|
bucketSet = bucketSetI.(*TokenBucketSet)
|
|
bucketSet.Update(effectiveRates)
|
|
} else {
|
|
bucketSet = NewTokenBucketSet(effectiveRates, tl.clock)
|
|
// We set ttl as 10 times rate period. E.g. if rate is 100 requests/second per client ip
|
|
// the counters for this ip will expire after 10 seconds of inactivity
|
|
tl.bucketSets.Set(source, bucketSet, int(bucketSet.maxPeriod/time.Second)*10+1)
|
|
}
|
|
delay, err := bucketSet.Consume(amount)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if delay > 0 {
|
|
return &MaxRateError{delay: delay}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// effectiveRates retrieves rates to be applied to the request.
|
|
func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
|
|
// If configuration mapper is not specified for this instance, then return
|
|
// the default bucket specs.
|
|
if tl.extractRates == nil {
|
|
return tl.defaultRates
|
|
}
|
|
|
|
rates, err := tl.extractRates.Extract(req)
|
|
if err != nil {
|
|
log.Errorf("Failed to retrieve rates: %v", err)
|
|
return tl.defaultRates
|
|
}
|
|
|
|
// If the returned rate set is empty then used the default one.
|
|
if len(rates.m) == 0 {
|
|
return tl.defaultRates
|
|
}
|
|
|
|
return rates
|
|
}
|
|
|
|
type MaxRateError struct {
|
|
delay time.Duration
|
|
}
|
|
|
|
func (m *MaxRateError) Error() string {
|
|
return fmt.Sprintf("max rate reached: retry-in %v", m.delay)
|
|
}
|
|
|
|
type RateErrHandler struct {
|
|
}
|
|
|
|
func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
|
|
if rerr, ok := err.(*MaxRateError); ok {
|
|
w.Header().Set("X-Retry-In", rerr.delay.String())
|
|
w.WriteHeader(429)
|
|
w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
utils.DefaultHandler.ServeHTTP(w, req, err)
|
|
}
|
|
|
|
type TokenLimiterOption func(l *TokenLimiter) error
|
|
|
|
// ErrorHandler sets error handler of the server
|
|
func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption {
|
|
return func(cl *TokenLimiter) error {
|
|
cl.errHandler = h
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func ExtractRates(e RateExtractor) TokenLimiterOption {
|
|
return func(cl *TokenLimiter) error {
|
|
cl.extractRates = e
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func Clock(clock timetools.TimeProvider) TokenLimiterOption {
|
|
return func(cl *TokenLimiter) error {
|
|
cl.clock = clock
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func Capacity(cap int) TokenLimiterOption {
|
|
return func(cl *TokenLimiter) error {
|
|
if cap <= 0 {
|
|
return fmt.Errorf("bad capacity: %v", cap)
|
|
}
|
|
cl.capacity = cap
|
|
return nil
|
|
}
|
|
}
|
|
|
|
var defaultErrHandler = &RateErrHandler{}
|
|
|
|
func setDefaults(tl *TokenLimiter) {
|
|
if tl.capacity <= 0 {
|
|
tl.capacity = DefaultCapacity
|
|
}
|
|
if tl.clock == nil {
|
|
tl.clock = &timetools.RealTime{}
|
|
}
|
|
if tl.errHandler == nil {
|
|
tl.errHandler = defaultErrHandler
|
|
}
|
|
}
|