2017-02-07 22:33:23 +01:00
|
|
|
package utils
|
|
|
|
|
|
|
|
import (
	"fmt"
	"net"
	"net/http"
	"strings"
)
|
|
|
|
|
2018-07-11 10:08:03 +02:00
|
|
|
// SourceExtractor identifies where a request comes from.
//
// Extract returns a token naming the source (for example the client IP, or the
// value of a particular header that identifies the source) together with the
// amount the source consumes; connection limiters usually use an amount of 1.
// Implementations should return a non-nil error when the source cannot be
// identified.
type SourceExtractor interface {
	Extract(r *http.Request) (token string, amount int64, err error)
}
|
|
|
|
|
2018-07-11 10:08:03 +02:00
|
|
|
// ExtractorFunc adapts an ordinary function with the Extract signature so it
// can be used wherever a SourceExtractor is expected. The function returns the
// source token, the amount consumed, and an error when the source cannot be
// identified.
type ExtractorFunc func(*http.Request) (string, int64, error)
|
|
|
|
|
2018-07-11 10:08:03 +02:00
|
|
|
// Extract extract from request
|
2017-02-07 22:33:23 +01:00
|
|
|
func (f ExtractorFunc) Extract(req *http.Request) (string, int64, error) {
|
|
|
|
return f(req)
|
|
|
|
}
|
|
|
|
|
2018-07-11 10:08:03 +02:00
|
|
|
// ExtractSource is a function type that receives an HTTP request and returns
// nothing.
//
// NOTE(review): unlike ExtractorFunc it has no results, so it cannot report a
// token itself, and it is not referenced by the code visible here — confirm
// whether it is still needed.
type ExtractSource func(r *http.Request)
|
|
|
|
|
2018-07-11 10:08:03 +02:00
|
|
|
// NewExtractor creates a new SourceExtractor
|
2017-02-07 22:33:23 +01:00
|
|
|
func NewExtractor(variable string) (SourceExtractor, error) {
|
|
|
|
if variable == "client.ip" {
|
|
|
|
return ExtractorFunc(extractClientIP), nil
|
|
|
|
}
|
|
|
|
if variable == "request.host" {
|
|
|
|
return ExtractorFunc(extractHost), nil
|
|
|
|
}
|
|
|
|
if strings.HasPrefix(variable, "request.header.") {
|
|
|
|
header := strings.TrimPrefix(variable, "request.header.")
|
|
|
|
if len(header) == 0 {
|
2018-07-11 10:08:03 +02:00
|
|
|
return nil, fmt.Errorf("wrong header: %s", header)
|
2017-02-07 22:33:23 +01:00
|
|
|
}
|
|
|
|
return makeHeaderExtractor(header), nil
|
|
|
|
}
|
2018-07-11 10:08:03 +02:00
|
|
|
return nil, fmt.Errorf("unsupported limiting variable: '%s'", variable)
|
2017-02-07 22:33:23 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// extractClientIP returns the host part of req.RemoteAddr as the source token,
// with an amount of 1.
//
// net.SplitHostPort is used so that bracketed IPv6 addresses such as
// "[::1]:8080" yield "::1". Splitting on the first ':' would yield "[" for
// every IPv6 client. If RemoteAddr carries no port, the address is used as-is,
// with surrounding brackets removed. An error is returned when no host can be
// determined.
func extractClientIP(req *http.Request) (string, int64, error) {
	addr := req.RemoteAddr
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		// No port (or otherwise not host:port form): fall back to the raw value.
		host = addr
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
	}
	if host == "" {
		return "", 0, fmt.Errorf("failed to parse client IP: %v", req.RemoteAddr)
	}
	return host, 1, nil
}
|
|
|
|
|
|
|
|
// extractHost uses the request's Host field, verbatim (including any port),
// as the source token with an amount of 1. It never fails.
func extractHost(req *http.Request) (string, int64, error) {
	const amount int64 = 1
	host := req.Host
	return host, amount, nil
}
|
|
|
|
|
|
|
|
func makeHeaderExtractor(header string) SourceExtractor {
|
|
|
|
return ExtractorFunc(func(req *http.Request) (string, int64, error) {
|
|
|
|
return req.Header.Get(header), 1, nil
|
|
|
|
})
|
|
|
|
}
|