Remove containous/mux from HTTP muxer

Co-authored-by: Simon Delicata <simon.delicata@traefik.io>
This commit is contained in:
Tom Moulard 2022-12-22 17:16:04 +01:00 committed by GitHub
parent 8c98234c07
commit c38d405cfd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 299 additions and 209 deletions

1
go.mod
View file

@ -385,7 +385,6 @@ require (
replace ( replace (
github.com/abbot/go-http-auth => github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e github.com/abbot/go-http-auth => github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e
github.com/go-check/check => github.com/containous/check v0.0.0-20170915194414-ca0bf163426a github.com/go-check/check => github.com/containous/check v0.0.0-20170915194414-ca0bf163426a
github.com/gorilla/mux => github.com/containous/mux v0.0.0-20220627093034-b2dd784e613f
github.com/mailgun/minheap => github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595 github.com/mailgun/minheap => github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595
) )

7
go.sum
View file

@ -457,8 +457,6 @@ github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e h1:D+uTE
github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e/go.mod h1:s8kLgBQolDbsJOPVIGCEEv9zGAKUUf/685Gi0Qqg8z8= github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e/go.mod h1:s8kLgBQolDbsJOPVIGCEEv9zGAKUUf/685Gi0Qqg8z8=
github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595 h1:aPspFRO6b94To3gl4yTDOEtpjFwXI7V2W+z0JcNljQ4= github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595 h1:aPspFRO6b94To3gl4yTDOEtpjFwXI7V2W+z0JcNljQ4=
github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595/go.mod h1:+lHFbEasIiQVGzhVDVw/cn0ZaOzde2OwNncp1NhXV4c= github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595/go.mod h1:+lHFbEasIiQVGzhVDVw/cn0ZaOzde2OwNncp1NhXV4c=
github.com/containous/mux v0.0.0-20220627093034-b2dd784e613f h1:1uEtynq2C0ljy3630jt7EAxg8jZY2gy6YHdGwdqEpWw=
github.com/containous/mux v0.0.0-20220627093034-b2dd784e613f/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg=
github.com/coredns/coredns v1.1.2/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0= github.com/coredns/coredns v1.1.2/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0=
github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
@ -936,6 +934,11 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ=
github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

View file

@ -7,14 +7,13 @@ import (
"strings" "strings"
"unicode/utf8" "unicode/utf8"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/traefik/traefik/v2/pkg/ip" "github.com/traefik/traefik/v2/pkg/ip"
"github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator" "github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
var httpFuncs = map[string]func(*mux.Route, ...string) error{ var httpFuncs = map[string]func(*matchersTree, ...string) error{
"ClientIP": expectNParameters(clientIP, 1), "ClientIP": expectNParameters(clientIP, 1),
"Method": expectNParameters(method, 1), "Method": expectNParameters(method, 1),
"Host": expectNParameters(host, 1), "Host": expectNParameters(host, 1),
@ -28,17 +27,17 @@ var httpFuncs = map[string]func(*mux.Route, ...string) error{
"QueryRegexp": expectNParameters(queryRegexp, 1, 2), "QueryRegexp": expectNParameters(queryRegexp, 1, 2),
} }
func expectNParameters(fn func(*mux.Route, ...string) error, n ...int) func(*mux.Route, ...string) error { func expectNParameters(fn func(*matchersTree, ...string) error, n ...int) func(*matchersTree, ...string) error {
return func(route *mux.Route, s ...string) error { return func(tree *matchersTree, s ...string) error {
if !slices.Contains(n, len(s)) { if !slices.Contains(n, len(s)) {
return fmt.Errorf("unexpected number of parameters; got %d, expected one of %v", len(s), n) return fmt.Errorf("unexpected number of parameters; got %d, expected one of %v", len(s), n)
} }
return fn(route, s...) return fn(tree, s...)
} }
} }
func clientIP(route *mux.Route, clientIP ...string) error { func clientIP(tree *matchersTree, clientIP ...string) error {
checker, err := ip.NewChecker(clientIP) checker, err := ip.NewChecker(clientIP)
if err != nil { if err != nil {
return fmt.Errorf("initializing IP checker for ClientIP matcher: %w", err) return fmt.Errorf("initializing IP checker for ClientIP matcher: %w", err)
@ -46,7 +45,7 @@ func clientIP(route *mux.Route, clientIP ...string) error {
strategy := ip.RemoteAddrStrategy{} strategy := ip.RemoteAddrStrategy{}
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
ok, err := checker.Contains(strategy.GetIP(req)) ok, err := checker.Contains(strategy.GetIP(req))
if err != nil { if err != nil {
log.Ctx(req.Context()).Warn().Err(err).Msg("ClientIP matcher: could not match remote address") log.Ctx(req.Context()).Warn().Err(err).Msg("ClientIP matcher: could not match remote address")
@ -54,16 +53,22 @@ func clientIP(route *mux.Route, clientIP ...string) error {
} }
return ok return ok
}) }
return nil return nil
} }
func method(route *mux.Route, methods ...string) error { func method(tree *matchersTree, methods ...string) error {
return route.Methods(methods...).GetError() method := strings.ToUpper(methods[0])
tree.matcher = func(req *http.Request) bool {
return method == req.Method
}
return nil
} }
func host(route *mux.Route, hosts ...string) error { func host(tree *matchersTree, hosts ...string) error {
host := hosts[0] host := hosts[0]
if !IsASCII(host) { if !IsASCII(host) {
@ -72,7 +77,7 @@ func host(route *mux.Route, hosts ...string) error {
host = strings.ToLower(host) host = strings.ToLower(host)
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
reqHost := requestdecorator.GetCanonizedHost(req.Context()) reqHost := requestdecorator.GetCanonizedHost(req.Context())
if len(reqHost) == 0 { if len(reqHost) == 0 {
return false return false
@ -104,12 +109,12 @@ func host(route *mux.Route, hosts ...string) error {
} }
return false return false
}) }
return nil return nil
} }
func hostRegexp(route *mux.Route, hosts ...string) error { func hostRegexp(tree *matchersTree, hosts ...string) error {
host := hosts[0] host := hosts[0]
if !IsASCII(host) { if !IsASCII(host) {
@ -121,29 +126,29 @@ func hostRegexp(route *mux.Route, hosts ...string) error {
return fmt.Errorf("compiling HostRegexp matcher: %w", err) return fmt.Errorf("compiling HostRegexp matcher: %w", err)
} }
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
return re.MatchString(requestdecorator.GetCanonizedHost(req.Context())) || return re.MatchString(requestdecorator.GetCanonizedHost(req.Context())) ||
re.MatchString(requestdecorator.GetCNAMEFlatten(req.Context())) re.MatchString(requestdecorator.GetCNAMEFlatten(req.Context()))
}) }
return nil return nil
} }
func path(route *mux.Route, paths ...string) error { func path(tree *matchersTree, paths ...string) error {
path := paths[0] path := paths[0]
if !strings.HasPrefix(path, "/") { if !strings.HasPrefix(path, "/") {
return fmt.Errorf("path %q does not start with a '/'", path) return fmt.Errorf("path %q does not start with a '/'", path)
} }
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
return req.URL.Path == path return req.URL.Path == path
}) }
return nil return nil
} }
func pathRegexp(route *mux.Route, paths ...string) error { func pathRegexp(tree *matchersTree, paths ...string) error {
path := paths[0] path := paths[0]
re, err := regexp.Compile(path) re, err := regexp.Compile(path)
@ -151,36 +156,65 @@ func pathRegexp(route *mux.Route, paths ...string) error {
return fmt.Errorf("compiling PathPrefix matcher: %w", err) return fmt.Errorf("compiling PathPrefix matcher: %w", err)
} }
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
return re.MatchString(req.URL.Path) return re.MatchString(req.URL.Path)
}) }
return nil return nil
} }
func pathPrefix(route *mux.Route, paths ...string) error { func pathPrefix(tree *matchersTree, paths ...string) error {
path := paths[0] path := paths[0]
if !strings.HasPrefix(path, "/") { if !strings.HasPrefix(path, "/") {
return fmt.Errorf("path %q does not start with a '/'", path) return fmt.Errorf("path %q does not start with a '/'", path)
} }
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
return strings.HasPrefix(req.URL.Path, path) return strings.HasPrefix(req.URL.Path, path)
}) }
return nil return nil
} }
func header(route *mux.Route, headers ...string) error { func header(tree *matchersTree, headers ...string) error {
return route.Headers(headers...).GetError() key, value := http.CanonicalHeaderKey(headers[0]), headers[1]
tree.matcher = func(req *http.Request) bool {
for _, headerValue := range req.Header[key] {
if headerValue == value {
return true
}
}
return false
}
return nil
} }
func headerRegexp(route *mux.Route, headers ...string) error { func headerRegexp(tree *matchersTree, headers ...string) error {
return route.HeadersRegexp(headers...).GetError() key, value := http.CanonicalHeaderKey(headers[0]), headers[1]
re, err := regexp.Compile(value)
if err != nil {
return fmt.Errorf("compiling HeaderRegexp matcher: %w", err)
}
tree.matcher = func(req *http.Request) bool {
for _, headerValue := range req.Header[key] {
if re.MatchString(headerValue) {
return true
}
}
return false
}
return nil
} }
func query(route *mux.Route, queries ...string) error { func query(tree *matchersTree, queries ...string) error {
key := queries[0] key := queries[0]
var value string var value string
@ -188,21 +222,21 @@ func query(route *mux.Route, queries ...string) error {
value = queries[1] value = queries[1]
} }
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
values, ok := req.URL.Query()[key] values, ok := req.URL.Query()[key]
if !ok { if !ok {
return false return false
} }
return slices.Contains(values, value) return slices.Contains(values, value)
}) }
return nil return nil
} }
func queryRegexp(route *mux.Route, queries ...string) error { func queryRegexp(tree *matchersTree, queries ...string) error {
if len(queries) == 1 { if len(queries) == 1 {
return query(route, queries...) return query(tree, queries...)
} }
key, value := queries[0], queries[1] key, value := queries[0], queries[1]
@ -212,7 +246,7 @@ func queryRegexp(route *mux.Route, queries ...string) error {
return fmt.Errorf("compiling QueryRegexp matcher: %w", err) return fmt.Errorf("compiling QueryRegexp matcher: %w", err)
} }
route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { tree.matcher = func(req *http.Request) bool {
values, ok := req.URL.Query()[key] values, ok := req.URL.Query()[key]
if !ok { if !ok {
return false return false
@ -223,7 +257,7 @@ func queryRegexp(route *mux.Route, queries ...string) error {
}) })
return idx >= 0 return idx >= 0
}) }
return nil return nil
} }

View file

@ -3,6 +3,7 @@ package http
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -122,7 +123,8 @@ func TestMethodMatcher(t *testing.T) {
rule: "Method(`GET`)", rule: "Method(`GET`)",
expected: map[string]int{ expected: map[string]int{
http.MethodGet: http.StatusOK, http.MethodGet: http.StatusOK,
http.MethodPost: http.StatusMethodNotAllowed, http.MethodPost: http.StatusNotFound,
strings.ToLower(http.MethodGet): http.StatusNotFound,
}, },
}, },
{ {
@ -130,7 +132,8 @@ func TestMethodMatcher(t *testing.T) {
rule: "Method(`get`)", rule: "Method(`get`)",
expected: map[string]int{ expected: map[string]int{
http.MethodGet: http.StatusOK, http.MethodGet: http.StatusOK,
http.MethodPost: http.StatusMethodNotAllowed, http.MethodPost: http.StatusNotFound,
strings.ToLower(http.MethodGet): http.StatusNotFound,
}, },
}, },
} }
@ -200,6 +203,7 @@ func TestHostMatcher(t *testing.T) {
"https://example.com": http.StatusOK, "https://example.com": http.StatusOK,
"https://example.com:8080": http.StatusOK, "https://example.com:8080": http.StatusOK,
"https://example.com/path": http.StatusOK, "https://example.com/path": http.StatusOK,
"https://EXAMPLE.COM/path": http.StatusOK,
"https://example.org": http.StatusNotFound, "https://example.org": http.StatusNotFound,
"https://example.org/path": http.StatusNotFound, "https://example.org/path": http.StatusNotFound,
}, },
@ -665,6 +669,17 @@ func TestHeaderMatcher(t *testing.T) {
{"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound, {"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound,
}, },
}, },
{
desc: "valid Header matcher (non-canonical form)",
rule: "Header(`x-forwarded-proto`, `https`)",
expected: map[*http.Header]int{
{"X-Forwarded-Proto": []string{"https"}}: http.StatusOK,
{"x-forwarded-proto": []string{"https"}}: http.StatusNotFound,
{"X-Forwarded-Proto": []string{"http", "https"}}: http.StatusOK,
{"X-Forwarded-Proto": []string{"https", "http"}}: http.StatusOK,
{"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound,
},
},
} }
for _, test := range testCases { for _, test := range testCases {
@ -747,6 +762,18 @@ func TestHeaderRegexpMatcher(t *testing.T) {
{"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound, {"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound,
}, },
}, },
{
desc: "valid HeaderRegexp matcher (non-canonical form)",
rule: "HeaderRegexp(`x-forwarded-proto`, `^https?$`)",
expected: map[*http.Header]int{
{"X-Forwarded-Proto": []string{"http"}}: http.StatusOK,
{"x-forwarded-proto": []string{"http"}}: http.StatusNotFound,
{"X-Forwarded-Proto": []string{"https"}}: http.StatusOK,
{"X-Forwarded-Proto": []string{"HTTPS"}}: http.StatusNotFound,
{"X-Forwarded-Proto": []string{"ws", "https"}}: http.StatusOK,
{"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound,
},
},
{ {
desc: "valid HeaderRegexp matcher with Traefik v2 syntax", desc: "valid HeaderRegexp matcher with Traefik v2 syntax",
rule: "HeaderRegexp(`X-Forwarded-Proto`, `http{secure:s?}`)", rule: "HeaderRegexp(`X-Forwarded-Proto`, `http{secure:s?}`)",

View file

@ -3,15 +3,16 @@ package http
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"sort"
"github.com/gorilla/mux" "github.com/rs/zerolog/log"
"github.com/traefik/traefik/v2/pkg/rules" "github.com/traefik/traefik/v2/pkg/rules"
"github.com/vulcand/predicate" "github.com/vulcand/predicate"
) )
// Muxer handles routing with rules. // Muxer handles routing with rules.
type Muxer struct { type Muxer struct {
*mux.Router routes routes
parser predicate.Parser parser predicate.Parser
} }
@ -24,18 +25,30 @@ func NewMuxer() (*Muxer, error) {
parser, err := rules.NewParser(matchers) parser, err := rules.NewParser(matchers)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("error while creating parser: %w", err)
} }
return &Muxer{ return &Muxer{
Router: mux.NewRouter().SkipClean(true),
parser: parser, parser: parser,
}, nil }, nil
} }
// ServeHTTP forwards the connection to the matching HTTP handler.
// Serves 404 if no handler is found.
func (m *Muxer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
for _, route := range m.routes {
if route.matchers.match(req) {
route.handler.ServeHTTP(rw, req)
return
}
}
http.NotFoundHandler().ServeHTTP(rw, req)
}
// AddRoute add a new route to the router. // AddRoute add a new route to the router.
func (r *Muxer) AddRoute(rule string, priority int, handler http.Handler) error { func (m *Muxer) AddRoute(rule string, priority int, handler http.Handler) error {
parse, err := r.parser.Parse(rule) parse, err := m.parser.Parse(rule)
if err != nil { if err != nil {
return fmt.Errorf("error while parsing rule %s: %w", rule, err) return fmt.Errorf("error while parsing rule %s: %w", rule, err)
} }
@ -45,99 +58,25 @@ func (r *Muxer) AddRoute(rule string, priority int, handler http.Handler) error
return fmt.Errorf("error while parsing rule %s", rule) return fmt.Errorf("error while parsing rule %s", rule)
} }
var matchers matchersTree
err = matchers.addRule(buildTree())
if err != nil {
return fmt.Errorf("error while adding rule %s: %w", rule, err)
}
if priority == 0 { if priority == 0 {
priority = len(rule) priority = len(rule)
} }
route := r.NewRoute().Handler(handler).Priority(priority) m.routes = append(m.routes, &route{
handler: handler,
err = addRuleOnRoute(route, buildTree()) matchers: matchers,
if err != nil { priority: priority,
route.BuildOnly()
return err
}
return nil
}
func addRuleOnRouter(router *mux.Router, rule *rules.Tree) error {
switch rule.Matcher {
case "and":
route := router.NewRoute()
err := addRuleOnRoute(route, rule.RuleLeft)
if err != nil {
return err
}
return addRuleOnRoute(route, rule.RuleRight)
case "or":
err := addRuleOnRouter(router, rule.RuleLeft)
if err != nil {
return err
}
return addRuleOnRouter(router, rule.RuleRight)
default:
err := rules.CheckRule(rule)
if err != nil {
return err
}
if rule.Not {
return not(httpFuncs[rule.Matcher])(router.NewRoute(), rule.Value...)
}
return httpFuncs[rule.Matcher](router.NewRoute(), rule.Value...)
}
}
func addRuleOnRoute(route *mux.Route, rule *rules.Tree) error {
switch rule.Matcher {
case "and":
err := addRuleOnRoute(route, rule.RuleLeft)
if err != nil {
return err
}
return addRuleOnRoute(route, rule.RuleRight)
case "or":
subRouter := route.Subrouter()
err := addRuleOnRouter(subRouter, rule.RuleLeft)
if err != nil {
return err
}
return addRuleOnRouter(subRouter, rule.RuleRight)
default:
err := rules.CheckRule(rule)
if err != nil {
return err
}
if rule.Not {
return not(httpFuncs[rule.Matcher])(route, rule.Value...)
}
return httpFuncs[rule.Matcher](route, rule.Value...)
}
}
func not(m func(*mux.Route, ...string) error) func(*mux.Route, ...string) error {
return func(r *mux.Route, v ...string) error {
router := mux.NewRouter()
err := m(router.NewRoute(), v...)
if err != nil {
return err
}
r.MatcherFunc(func(req *http.Request, ma *mux.RouteMatch) bool {
return !router.Match(req, ma)
}) })
sort.Sort(m.routes)
return nil return nil
}
} }
// ParseDomains extract domains from rule. // ParseDomains extract domains from rule.
@ -149,12 +88,12 @@ func ParseDomains(rule string) ([]string, error) {
parser, err := rules.NewParser(matchers) parser, err := rules.NewParser(matchers)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("error while creating parser: %w", err)
} }
parse, err := parser.Parse(rule) parse, err := parser.Parse(rule)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("error while parsing rule %s: %w", rule, err)
} }
buildTree, ok := parse.(rules.TreeBuilder) buildTree, ok := parse.(rules.TreeBuilder)
@ -164,3 +103,97 @@ func ParseDomains(rule string) ([]string, error) {
return buildTree().ParseMatchers([]string{"Host"}), nil return buildTree().ParseMatchers([]string{"Host"}), nil
} }
// routes implements sort.Interface.
type routes []*route
// Len implements sort.Interface.
func (r routes) Len() int { return len(r) }
// Swap implements sort.Interface.
func (r routes) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
// Less implements sort.Interface.
func (r routes) Less(i, j int) bool { return r[i].priority > r[j].priority }
// route holds the matchers to match HTTP route,
// and the handler that will serve the request.
type route struct {
// matchers tree structure reflecting the rule.
matchers matchersTree
// handler responsible for handling the route.
handler http.Handler
// priority is used to disambiguate between two (or more) rules that would all match for a given request.
// Computed from the matching rule length, if not user-set.
priority int
}
// matchersTree represents the matchers tree structure.
type matchersTree struct {
// matcher is a matcher func used to match HTTP request properties.
// If matcher is not nil, it means that this matcherTree is a leaf of the tree.
// It is therefore mutually exclusive with left and right.
matcher func(*http.Request) bool
// operator to combine the evaluation of left and right leaves.
operator string
// Mutually exclusive with matcher.
left *matchersTree
right *matchersTree
}
func (m *matchersTree) match(req *http.Request) bool {
if m == nil {
// This should never happen as it should have been detected during parsing.
log.Warn().Msg("Rule matcher is nil")
return false
}
if m.matcher != nil {
return m.matcher(req)
}
switch m.operator {
case "or":
return m.left.match(req) || m.right.match(req)
case "and":
return m.left.match(req) && m.right.match(req)
default:
// This should never happen as it should have been detected during parsing.
log.Warn().Str("operator", m.operator).Msg("Invalid rule operator")
return false
}
}
func (m *matchersTree) addRule(rule *rules.Tree) error {
switch rule.Matcher {
case "and", "or":
m.operator = rule.Matcher
m.left = &matchersTree{}
err := m.left.addRule(rule.RuleLeft)
if err != nil {
return fmt.Errorf("error while adding rule %s: %w", rule.Matcher, err)
}
m.right = &matchersTree{}
return m.right.addRule(rule.RuleRight)
default:
err := rules.CheckRule(rule)
if err != nil {
return fmt.Errorf("error while checking rule %s: %w", rule.Matcher, err)
}
err = httpFuncs[rule.Matcher](m, rule.Value...)
if err != nil {
return fmt.Errorf("error while adding rule %s: %w", rule.Matcher, err)
}
if rule.Not {
matcherFunc := m.matcher
m.matcher = func(req *http.Request) bool {
return !matcherFunc(req)
}
}
}
return nil
}

View file

@ -380,8 +380,6 @@ func Test_addRoutePriority(t *testing.T) {
require.NoError(t, err, route.rule) require.NoError(t, err, route.rule)
} }
muxer.SortRoutes()
w := httptest.NewRecorder() w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, test.path, http.NoBody) req := testhelpers.MustNewRequest(http.MethodGet, test.path, http.NoBody)

View file

@ -13,32 +13,6 @@ import (
"github.com/vulcand/predicate" "github.com/vulcand/predicate"
) )
// ParseHostSNI extracts the HostSNIs declared in a rule.
// This is a first naive implementation used in TCP routing.
func ParseHostSNI(rule string) ([]string, error) {
var matchers []string
for matcher := range tcpFuncs {
matchers = append(matchers, matcher)
}
parser, err := rules.NewParser(matchers)
if err != nil {
return nil, err
}
parse, err := parser.Parse(rule)
if err != nil {
return nil, err
}
buildTree, ok := parse.(rules.TreeBuilder)
if !ok {
return nil, fmt.Errorf("error while parsing rule %s", rule)
}
return buildTree().ParseMatchers([]string{"HostSNI"}), nil
}
// ConnData contains TCP connection metadata. // ConnData contains TCP connection metadata.
type ConnData struct { type ConnData struct {
serverName string serverName string
@ -67,7 +41,7 @@ func NewConnData(serverName string, conn tcp.WriteCloser, alpnProtos []string) (
// Muxer defines a muxer that handles TCP routing with rules. // Muxer defines a muxer that handles TCP routing with rules.
type Muxer struct { type Muxer struct {
routes []*route routes routes
parser predicate.Parser parser predicate.Parser
} }
@ -114,9 +88,9 @@ func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error {
ruleTree := buildTree() ruleTree := buildTree()
var matchers matchersTree var matchers matchersTree
err = addRule(&matchers, ruleTree) err = matchers.addRule(ruleTree)
if err != nil { if err != nil {
return err return fmt.Errorf("error while adding rule %s: %w", rule, err)
} }
var catchAll bool var catchAll bool
@ -144,41 +118,7 @@ func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error {
} }
m.routes = append(m.routes, newRoute) m.routes = append(m.routes, newRoute)
sort.Sort(routes(m.routes)) sort.Sort(m.routes)
return nil
}
func addRule(tree *matchersTree, rule *rules.Tree) error {
switch rule.Matcher {
case "and", "or":
tree.operator = rule.Matcher
tree.left = &matchersTree{}
err := addRule(tree.left, rule.RuleLeft)
if err != nil {
return err
}
tree.right = &matchersTree{}
return addRule(tree.right, rule.RuleRight)
default:
err := rules.CheckRule(rule)
if err != nil {
return err
}
err = tcpFuncs[rule.Matcher](tree, rule.Value...)
if err != nil {
return err
}
if rule.Not {
matcherFunc := tree.matcher
tree.matcher = func(meta ConnData) bool {
return !matcherFunc(meta)
}
}
}
return nil return nil
} }
@ -188,6 +128,32 @@ func (m *Muxer) HasRoutes() bool {
return len(m.routes) > 0 return len(m.routes) > 0
} }
// ParseHostSNI extracts the HostSNIs declared in a rule.
// This is a first naive implementation used in TCP routing.
func ParseHostSNI(rule string) ([]string, error) {
var matchers []string
for matcher := range tcpFuncs {
matchers = append(matchers, matcher)
}
parser, err := rules.NewParser(matchers)
if err != nil {
return nil, err
}
parse, err := parser.Parse(rule)
if err != nil {
return nil, err
}
buildTree, ok := parse.(rules.TreeBuilder)
if !ok {
return nil, fmt.Errorf("error while parsing rule %s", rule)
}
return buildTree().ParseMatchers([]string{"HostSNI"}), nil
}
// routes implements sort.Interface. // routes implements sort.Interface.
type routes []*route type routes []*route
@ -215,14 +181,12 @@ type route struct {
priority int priority int
} }
// matcher is a matcher func used to match connection properties.
type matcher func(meta ConnData) bool
// matchersTree represents the matchers tree structure. // matchersTree represents the matchers tree structure.
type matchersTree struct { type matchersTree struct {
// matcher is a matcher func used to match connection properties.
// If matcher is not nil, it means that this matcherTree is a leaf of the tree. // If matcher is not nil, it means that this matcherTree is a leaf of the tree.
// It is therefore mutually exclusive with left and right. // It is therefore mutually exclusive with left and right.
matcher matcher matcher func(ConnData) bool
// operator to combine the evaluation of left and right leaves. // operator to combine the evaluation of left and right leaves.
operator string operator string
// Mutually exclusive with matcher. // Mutually exclusive with matcher.
@ -252,3 +216,37 @@ func (m *matchersTree) match(meta ConnData) bool {
return false return false
} }
} }
func (m *matchersTree) addRule(rule *rules.Tree) error {
switch rule.Matcher {
case "and", "or":
m.operator = rule.Matcher
m.left = &matchersTree{}
err := m.left.addRule(rule.RuleLeft)
if err != nil {
return err
}
m.right = &matchersTree{}
return m.right.addRule(rule.RuleRight)
default:
err := rules.CheckRule(rule)
if err != nil {
return err
}
err = tcpFuncs[rule.Matcher](m, rule.Value...)
if err != nil {
return err
}
if rule.Not {
matcherFunc := m.matcher
m.matcher = func(meta ConnData) bool {
return !matcherFunc(meta)
}
}
}
return nil
}

View file

@ -134,8 +134,6 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string
} }
} }
muxer.SortRoutes()
chain := alice.New() chain := alice.New()
chain = chain.Append(func(next http.Handler) (http.Handler, error) { chain = chain.Append(func(next http.Handler) (http.Handler, error) {
return recovery.New(ctx, next) return recovery.New(ctx, next)