From c38d405cfd8bc652374ba38229a5cf26ad1cfa8d Mon Sep 17 00:00:00 2001 From: Tom Moulard Date: Thu, 22 Dec 2022 17:16:04 +0100 Subject: [PATCH] Remove containous/mux from HTTP muxer Co-authored-by: Simon Delicata --- go.mod | 1 - go.sum | 7 +- pkg/muxer/http/matcher.go | 106 ++++++++++------ pkg/muxer/http/matcher_test.go | 35 +++++- pkg/muxer/http/mux.go | 221 +++++++++++++++++++-------------- pkg/muxer/http/mux_test.go | 2 - pkg/muxer/tcp/mux.go | 134 ++++++++++---------- pkg/server/router/router.go | 2 - 8 files changed, 299 insertions(+), 209 deletions(-) diff --git a/go.mod b/go.mod index bca864e5b..c2aa3ba61 100644 --- a/go.mod +++ b/go.mod @@ -385,7 +385,6 @@ require ( replace ( github.com/abbot/go-http-auth => github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e github.com/go-check/check => github.com/containous/check v0.0.0-20170915194414-ca0bf163426a - github.com/gorilla/mux => github.com/containous/mux v0.0.0-20220627093034-b2dd784e613f github.com/mailgun/minheap => github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595 ) diff --git a/go.sum b/go.sum index 8d94db2f0..d94e2248a 100644 --- a/go.sum +++ b/go.sum @@ -457,8 +457,6 @@ github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e h1:D+uTE github.com/containous/go-http-auth v0.4.1-0.20200324110947-a37a7636d23e/go.mod h1:s8kLgBQolDbsJOPVIGCEEv9zGAKUUf/685Gi0Qqg8z8= github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595 h1:aPspFRO6b94To3gl4yTDOEtpjFwXI7V2W+z0JcNljQ4= github.com/containous/minheap v0.0.0-20190809180810-6e71eb837595/go.mod h1:+lHFbEasIiQVGzhVDVw/cn0ZaOzde2OwNncp1NhXV4c= -github.com/containous/mux v0.0.0-20220627093034-b2dd784e613f h1:1uEtynq2C0ljy3630jt7EAxg8jZY2gy6YHdGwdqEpWw= -github.com/containous/mux v0.0.0-20220627093034-b2dd784e613f/go.mod h1:z8WW7n06n8/1xF9Jl9WmuDeZuHAhfL+bwarNjsciwwg= github.com/coredns/coredns v1.1.2/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= @@ -936,6 +934,11 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/handlers v0.0.0-20150720190736-60c7bfde3e33/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/pkg/muxer/http/matcher.go b/pkg/muxer/http/matcher.go index d30e6df61..124761fd2 100644 --- a/pkg/muxer/http/matcher.go +++ b/pkg/muxer/http/matcher.go @@ -7,14 +7,13 @@ import ( "strings" "unicode/utf8" - "github.com/gorilla/mux" "github.com/rs/zerolog/log" "github.com/traefik/traefik/v2/pkg/ip" "github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator" "golang.org/x/exp/slices" ) -var httpFuncs = map[string]func(*mux.Route, ...string) error{ +var httpFuncs = map[string]func(*matchersTree, ...string) error{ "ClientIP": expectNParameters(clientIP, 1), "Method": expectNParameters(method, 1), "Host": expectNParameters(host, 1), @@ -28,17 +27,17 @@ var httpFuncs = map[string]func(*mux.Route, ...string) error{ "QueryRegexp": expectNParameters(queryRegexp, 1, 2), } -func expectNParameters(fn func(*mux.Route, ...string) error, n ...int) func(*mux.Route, ...string) error { - return func(route *mux.Route, s ...string) error { +func expectNParameters(fn func(*matchersTree, ...string) error, n ...int) func(*matchersTree, ...string) error { + return func(tree *matchersTree, s ...string) error { if !slices.Contains(n, len(s)) { return fmt.Errorf("unexpected number of parameters; got %d, expected one of %v", len(s), n) } - return fn(route, s...) + return fn(tree, s...) } } -func clientIP(route *mux.Route, clientIP ...string) error { +func clientIP(tree *matchersTree, clientIP ...string) error { checker, err := ip.NewChecker(clientIP) if err != nil { return fmt.Errorf("initializing IP checker for ClientIP matcher: %w", err) @@ -46,7 +45,7 @@ func clientIP(route *mux.Route, clientIP ...string) error { strategy := ip.RemoteAddrStrategy{} - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { ok, err := checker.Contains(strategy.GetIP(req)) if err != nil { log.Ctx(req.Context()).Warn().Err(err).Msg("ClientIP matcher: could not match remote address") @@ -54,16 +53,22 @@ func clientIP(route *mux.Route, clientIP ...string) error { } return ok - }) + } return nil } -func method(route *mux.Route, methods ...string) error { - return route.Methods(methods...).GetError() +func method(tree *matchersTree, methods ...string) error { + method := strings.ToUpper(methods[0]) + + tree.matcher = func(req *http.Request) bool { + return method == req.Method + } + + return nil } -func host(route *mux.Route, hosts ...string) error { +func host(tree *matchersTree, hosts ...string) error { host := hosts[0] if !IsASCII(host) { @@ -72,7 +77,7 @@ func host(route *mux.Route, hosts ...string) error { host = strings.ToLower(host) - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { reqHost := requestdecorator.GetCanonizedHost(req.Context()) if len(reqHost) == 0 { return false @@ -104,12 +109,12 @@ func host(route *mux.Route, hosts ...string) error { } return false - }) + } return nil } -func hostRegexp(route *mux.Route, hosts ...string) error { +func hostRegexp(tree *matchersTree, hosts ...string) error { host := hosts[0] if !IsASCII(host) { @@ -121,29 +126,29 @@ func hostRegexp(route *mux.Route, hosts ...string) error { return fmt.Errorf("compiling HostRegexp matcher: %w", err) } - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { return re.MatchString(requestdecorator.GetCanonizedHost(req.Context())) || re.MatchString(requestdecorator.GetCNAMEFlatten(req.Context())) - }) + } return nil } -func path(route *mux.Route, paths ...string) error { +func path(tree *matchersTree, paths ...string) error { path := paths[0] if !strings.HasPrefix(path, "/") { return fmt.Errorf("path %q does not start with a '/'", path) } - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { return req.URL.Path == path - }) + } return nil } -func pathRegexp(route *mux.Route, paths ...string) error { +func pathRegexp(tree *matchersTree, paths ...string) error { path := paths[0] re, err := regexp.Compile(path) @@ -151,36 +156,65 @@ func pathRegexp(route *mux.Route, paths ...string) error { return fmt.Errorf("compiling PathPrefix matcher: %w", err) } - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { return re.MatchString(req.URL.Path) - }) + } return nil } -func pathPrefix(route *mux.Route, paths ...string) error { +func pathPrefix(tree *matchersTree, paths ...string) error { path := paths[0] if !strings.HasPrefix(path, "/") { return fmt.Errorf("path %q does not start with a '/'", path) } - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { return strings.HasPrefix(req.URL.Path, path) - }) + } return nil } -func header(route *mux.Route, headers ...string) error { - return route.Headers(headers...).GetError() +func header(tree *matchersTree, headers ...string) error { + key, value := http.CanonicalHeaderKey(headers[0]), headers[1] + + tree.matcher = func(req *http.Request) bool { + for _, headerValue := range req.Header[key] { + if headerValue == value { + return true + } + } + + return false + } + + return nil } -func headerRegexp(route *mux.Route, headers ...string) error { - return route.HeadersRegexp(headers...).GetError() +func headerRegexp(tree *matchersTree, headers ...string) error { + key, value := http.CanonicalHeaderKey(headers[0]), headers[1] + + re, err := regexp.Compile(value) + if err != nil { + return fmt.Errorf("compiling HeaderRegexp matcher: %w", err) + } + + tree.matcher = func(req *http.Request) bool { + for _, headerValue := range req.Header[key] { + if re.MatchString(headerValue) { + return true + } + } + + return false + } + + return nil } -func query(route *mux.Route, queries ...string) error { +func query(tree *matchersTree, queries ...string) error { key := queries[0] var value string @@ -188,21 +222,21 @@ func query(route *mux.Route, queries ...string) error { value = queries[1] } - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { values, ok := req.URL.Query()[key] if !ok { return false } return slices.Contains(values, value) - }) + } return nil } -func queryRegexp(route *mux.Route, queries ...string) error { +func queryRegexp(tree *matchersTree, queries ...string) error { if len(queries) == 1 { - return query(route, queries...) + return query(tree, queries...) } key, value := queries[0], queries[1] @@ -212,7 +246,7 @@ func queryRegexp(route *mux.Route, queries ...string) error { return fmt.Errorf("compiling QueryRegexp matcher: %w", err) } - route.MatcherFunc(func(req *http.Request, _ *mux.RouteMatch) bool { + tree.matcher = func(req *http.Request) bool { values, ok := req.URL.Query()[key] if !ok { return false @@ -223,7 +257,7 @@ func queryRegexp(route *mux.Route, queries ...string) error { }) return idx >= 0 - }) + } return nil } diff --git a/pkg/muxer/http/matcher_test.go b/pkg/muxer/http/matcher_test.go index 2d31d8dc8..0c3d6dc9f 100644 --- a/pkg/muxer/http/matcher_test.go +++ b/pkg/muxer/http/matcher_test.go @@ -3,6 +3,7 @@ package http import ( "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -121,16 +122,18 @@ func TestMethodMatcher(t *testing.T) { desc: "valid Method matcher", rule: "Method(`GET`)", expected: map[string]int{ - http.MethodGet: http.StatusOK, - http.MethodPost: http.StatusMethodNotAllowed, + http.MethodGet: http.StatusOK, + http.MethodPost: http.StatusNotFound, + strings.ToLower(http.MethodGet): http.StatusNotFound, }, }, { desc: "valid Method matcher (lower case)", rule: "Method(`get`)", expected: map[string]int{ - http.MethodGet: http.StatusOK, - http.MethodPost: http.StatusMethodNotAllowed, + http.MethodGet: http.StatusOK, + http.MethodPost: http.StatusNotFound, + strings.ToLower(http.MethodGet): http.StatusNotFound, }, }, } @@ -200,6 +203,7 @@ func TestHostMatcher(t *testing.T) { "https://example.com": http.StatusOK, "https://example.com:8080": http.StatusOK, "https://example.com/path": http.StatusOK, + "https://EXAMPLE.COM/path": http.StatusOK, "https://example.org": http.StatusNotFound, "https://example.org/path": http.StatusNotFound, }, @@ -665,6 +669,17 @@ func TestHeaderMatcher(t *testing.T) { {"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound, }, }, + { + desc: "valid Header matcher (non-canonical form)", + rule: "Header(`x-forwarded-proto`, `https`)", + expected: map[*http.Header]int{ + {"X-Forwarded-Proto": []string{"https"}}: http.StatusOK, + {"x-forwarded-proto": []string{"https"}}: http.StatusNotFound, + {"X-Forwarded-Proto": []string{"http", "https"}}: http.StatusOK, + {"X-Forwarded-Proto": []string{"https", "http"}}: http.StatusOK, + {"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound, + }, + }, } for _, test := range testCases { @@ -747,6 +762,18 @@ func TestHeaderRegexpMatcher(t *testing.T) { {"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound, }, }, + { + desc: "valid HeaderRegexp matcher (non-canonical form)", + rule: "HeaderRegexp(`x-forwarded-proto`, `^https?$`)", + expected: map[*http.Header]int{ + {"X-Forwarded-Proto": []string{"http"}}: http.StatusOK, + {"x-forwarded-proto": []string{"http"}}: http.StatusNotFound, + {"X-Forwarded-Proto": []string{"https"}}: http.StatusOK, + {"X-Forwarded-Proto": []string{"HTTPS"}}: http.StatusNotFound, + {"X-Forwarded-Proto": []string{"ws", "https"}}: http.StatusOK, + {"X-Forwarded-Host": []string{"example.com"}}: http.StatusNotFound, + }, + }, { desc: "valid HeaderRegexp matcher with Traefik v2 syntax", rule: "HeaderRegexp(`X-Forwarded-Proto`, `http{secure:s?}`)", diff --git a/pkg/muxer/http/mux.go b/pkg/muxer/http/mux.go index 237977044..3039075ee 100644 --- a/pkg/muxer/http/mux.go +++ b/pkg/muxer/http/mux.go @@ -3,15 +3,16 @@ package http import ( "fmt" "net/http" + "sort" - "github.com/gorilla/mux" + "github.com/rs/zerolog/log" "github.com/traefik/traefik/v2/pkg/rules" "github.com/vulcand/predicate" ) // Muxer handles routing with rules. type Muxer struct { - *mux.Router + routes routes parser predicate.Parser } @@ -24,18 +25,30 @@ func NewMuxer() (*Muxer, error) { parser, err := rules.NewParser(matchers) if err != nil { - return nil, err + return nil, fmt.Errorf("error while creating parser: %w", err) } return &Muxer{ - Router: mux.NewRouter().SkipClean(true), parser: parser, }, nil } +// ServeHTTP forwards the connection to the matching HTTP handler. +// Serves 404 if no handler is found. +func (m *Muxer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + for _, route := range m.routes { + if route.matchers.match(req) { + route.handler.ServeHTTP(rw, req) + return + } + } + + http.NotFoundHandler().ServeHTTP(rw, req) +} + // AddRoute add a new route to the router. -func (r *Muxer) AddRoute(rule string, priority int, handler http.Handler) error { - parse, err := r.parser.Parse(rule) +func (m *Muxer) AddRoute(rule string, priority int, handler http.Handler) error { + parse, err := m.parser.Parse(rule) if err != nil { return fmt.Errorf("error while parsing rule %s: %w", rule, err) } @@ -45,101 +58,27 @@ func (r *Muxer) AddRoute(rule string, priority int, handler http.Handler) error return fmt.Errorf("error while parsing rule %s", rule) } + var matchers matchersTree + err = matchers.addRule(buildTree()) + if err != nil { + return fmt.Errorf("error while adding rule %s: %w", rule, err) + } + if priority == 0 { priority = len(rule) } - route := r.NewRoute().Handler(handler).Priority(priority) + m.routes = append(m.routes, &route{ + handler: handler, + matchers: matchers, + priority: priority, + }) - err = addRuleOnRoute(route, buildTree()) - if err != nil { - route.BuildOnly() - return err - } + sort.Sort(m.routes) return nil } -func addRuleOnRouter(router *mux.Router, rule *rules.Tree) error { - switch rule.Matcher { - case "and": - route := router.NewRoute() - err := addRuleOnRoute(route, rule.RuleLeft) - if err != nil { - return err - } - - return addRuleOnRoute(route, rule.RuleRight) - case "or": - err := addRuleOnRouter(router, rule.RuleLeft) - if err != nil { - return err - } - - return addRuleOnRouter(router, rule.RuleRight) - default: - err := rules.CheckRule(rule) - if err != nil { - return err - } - - if rule.Not { - return not(httpFuncs[rule.Matcher])(router.NewRoute(), rule.Value...) - } - - return httpFuncs[rule.Matcher](router.NewRoute(), rule.Value...) - } -} - -func addRuleOnRoute(route *mux.Route, rule *rules.Tree) error { - switch rule.Matcher { - case "and": - err := addRuleOnRoute(route, rule.RuleLeft) - if err != nil { - return err - } - - return addRuleOnRoute(route, rule.RuleRight) - case "or": - subRouter := route.Subrouter() - - err := addRuleOnRouter(subRouter, rule.RuleLeft) - if err != nil { - return err - } - - return addRuleOnRouter(subRouter, rule.RuleRight) - default: - err := rules.CheckRule(rule) - if err != nil { - return err - } - - if rule.Not { - return not(httpFuncs[rule.Matcher])(route, rule.Value...) - } - - return httpFuncs[rule.Matcher](route, rule.Value...) - } -} - -func not(m func(*mux.Route, ...string) error) func(*mux.Route, ...string) error { - return func(r *mux.Route, v ...string) error { - router := mux.NewRouter() - - err := m(router.NewRoute(), v...) - if err != nil { - return err - } - - r.MatcherFunc(func(req *http.Request, ma *mux.RouteMatch) bool { - return !router.Match(req, ma) - }) - - return nil - } -} - // ParseDomains extract domains from rule. func ParseDomains(rule string) ([]string, error) { var matchers []string @@ -149,12 +88,12 @@ func ParseDomains(rule string) ([]string, error) { parser, err := rules.NewParser(matchers) if err != nil { - return nil, err + return nil, fmt.Errorf("error while creating parser: %w", err) } parse, err := parser.Parse(rule) if err != nil { - return nil, err + return nil, fmt.Errorf("error while parsing rule %s: %w", rule, err) } buildTree, ok := parse.(rules.TreeBuilder) @@ -164,3 +103,97 @@ func ParseDomains(rule string) ([]string, error) { return buildTree().ParseMatchers([]string{"Host"}), nil } + +// routes implements sort.Interface. +type routes []*route + +// Len implements sort.Interface. +func (r routes) Len() int { return len(r) } + +// Swap implements sort.Interface. +func (r routes) Swap(i, j int) { r[i], r[j] = r[j], r[i] } + +// Less implements sort.Interface. +func (r routes) Less(i, j int) bool { return r[i].priority > r[j].priority } + +// route holds the matchers to match HTTP route, +// and the handler that will serve the request. +type route struct { + // matchers tree structure reflecting the rule. + matchers matchersTree + // handler responsible for handling the route. + handler http.Handler + // priority is used to disambiguate between two (or more) rules that would all match for a given request. + // Computed from the matching rule length, if not user-set. + priority int +} + +// matchersTree represents the matchers tree structure. +type matchersTree struct { + // matcher is a matcher func used to match HTTP request properties. + // If matcher is not nil, it means that this matcherTree is a leaf of the tree. + // It is therefore mutually exclusive with left and right. + matcher func(*http.Request) bool + // operator to combine the evaluation of left and right leaves. + operator string + // Mutually exclusive with matcher. + left *matchersTree + right *matchersTree +} + +func (m *matchersTree) match(req *http.Request) bool { + if m == nil { + // This should never happen as it should have been detected during parsing. + log.Warn().Msg("Rule matcher is nil") + return false + } + + if m.matcher != nil { + return m.matcher(req) + } + + switch m.operator { + case "or": + return m.left.match(req) || m.right.match(req) + case "and": + return m.left.match(req) && m.right.match(req) + default: + // This should never happen as it should have been detected during parsing. + log.Warn().Str("operator", m.operator).Msg("Invalid rule operator") + return false + } +} + +func (m *matchersTree) addRule(rule *rules.Tree) error { + switch rule.Matcher { + case "and", "or": + m.operator = rule.Matcher + m.left = &matchersTree{} + err := m.left.addRule(rule.RuleLeft) + if err != nil { + return fmt.Errorf("error while adding rule %s: %w", rule.Matcher, err) + } + + m.right = &matchersTree{} + return m.right.addRule(rule.RuleRight) + default: + err := rules.CheckRule(rule) + if err != nil { + return fmt.Errorf("error while checking rule %s: %w", rule.Matcher, err) + } + + err = httpFuncs[rule.Matcher](m, rule.Value...) + if err != nil { + return fmt.Errorf("error while adding rule %s: %w", rule.Matcher, err) + } + + if rule.Not { + matcherFunc := m.matcher + m.matcher = func(req *http.Request) bool { + return !matcherFunc(req) + } + } + } + + return nil +} diff --git a/pkg/muxer/http/mux_test.go b/pkg/muxer/http/mux_test.go index bdc98df37..0c1903de9 100644 --- a/pkg/muxer/http/mux_test.go +++ b/pkg/muxer/http/mux_test.go @@ -380,8 +380,6 @@ func Test_addRoutePriority(t *testing.T) { require.NoError(t, err, route.rule) } - muxer.SortRoutes() - w := httptest.NewRecorder() req := testhelpers.MustNewRequest(http.MethodGet, test.path, http.NoBody) diff --git a/pkg/muxer/tcp/mux.go b/pkg/muxer/tcp/mux.go index 3c5abffaf..00ec577af 100644 --- a/pkg/muxer/tcp/mux.go +++ b/pkg/muxer/tcp/mux.go @@ -13,32 +13,6 @@ import ( "github.com/vulcand/predicate" ) -// ParseHostSNI extracts the HostSNIs declared in a rule. -// This is a first naive implementation used in TCP routing. -func ParseHostSNI(rule string) ([]string, error) { - var matchers []string - for matcher := range tcpFuncs { - matchers = append(matchers, matcher) - } - - parser, err := rules.NewParser(matchers) - if err != nil { - return nil, err - } - - parse, err := parser.Parse(rule) - if err != nil { - return nil, err - } - - buildTree, ok := parse.(rules.TreeBuilder) - if !ok { - return nil, fmt.Errorf("error while parsing rule %s", rule) - } - - return buildTree().ParseMatchers([]string{"HostSNI"}), nil -} - // ConnData contains TCP connection metadata. type ConnData struct { serverName string @@ -67,7 +41,7 @@ func NewConnData(serverName string, conn tcp.WriteCloser, alpnProtos []string) ( // Muxer defines a muxer that handles TCP routing with rules. type Muxer struct { - routes []*route + routes routes parser predicate.Parser } @@ -114,9 +88,9 @@ func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error { ruleTree := buildTree() var matchers matchersTree - err = addRule(&matchers, ruleTree) + err = matchers.addRule(ruleTree) if err != nil { - return err + return fmt.Errorf("error while adding rule %s: %w", rule, err) } var catchAll bool @@ -144,41 +118,7 @@ func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error { } m.routes = append(m.routes, newRoute) - sort.Sort(routes(m.routes)) - - return nil -} - -func addRule(tree *matchersTree, rule *rules.Tree) error { - switch rule.Matcher { - case "and", "or": - tree.operator = rule.Matcher - tree.left = &matchersTree{} - err := addRule(tree.left, rule.RuleLeft) - if err != nil { - return err - } - - tree.right = &matchersTree{} - return addRule(tree.right, rule.RuleRight) - default: - err := rules.CheckRule(rule) - if err != nil { - return err - } - - err = tcpFuncs[rule.Matcher](tree, rule.Value...) - if err != nil { - return err - } - - if rule.Not { - matcherFunc := tree.matcher - tree.matcher = func(meta ConnData) bool { - return !matcherFunc(meta) - } - } - } + sort.Sort(m.routes) return nil } @@ -188,6 +128,32 @@ func (m *Muxer) HasRoutes() bool { return len(m.routes) > 0 } +// ParseHostSNI extracts the HostSNIs declared in a rule. +// This is a first naive implementation used in TCP routing. +func ParseHostSNI(rule string) ([]string, error) { + var matchers []string + for matcher := range tcpFuncs { + matchers = append(matchers, matcher) + } + + parser, err := rules.NewParser(matchers) + if err != nil { + return nil, err + } + + parse, err := parser.Parse(rule) + if err != nil { + return nil, err + } + + buildTree, ok := parse.(rules.TreeBuilder) + if !ok { + return nil, fmt.Errorf("error while parsing rule %s", rule) + } + + return buildTree().ParseMatchers([]string{"HostSNI"}), nil +} + // routes implements sort.Interface. type routes []*route @@ -215,14 +181,12 @@ type route struct { priority int } -// matcher is a matcher func used to match connection properties. -type matcher func(meta ConnData) bool - // matchersTree represents the matchers tree structure. type matchersTree struct { + // matcher is a matcher func used to match connection properties. // If matcher is not nil, it means that this matcherTree is a leaf of the tree. // It is therefore mutually exclusive with left and right. - matcher matcher + matcher func(ConnData) bool // operator to combine the evaluation of left and right leaves. operator string // Mutually exclusive with matcher. @@ -252,3 +216,37 @@ func (m *matchersTree) match(meta ConnData) bool { return false } } + +func (m *matchersTree) addRule(rule *rules.Tree) error { + switch rule.Matcher { + case "and", "or": + m.operator = rule.Matcher + m.left = &matchersTree{} + err := m.left.addRule(rule.RuleLeft) + if err != nil { + return err + } + + m.right = &matchersTree{} + return m.right.addRule(rule.RuleRight) + default: + err := rules.CheckRule(rule) + if err != nil { + return err + } + + err = tcpFuncs[rule.Matcher](m, rule.Value...) + if err != nil { + return err + } + + if rule.Not { + matcherFunc := m.matcher + m.matcher = func(meta ConnData) bool { + return !matcherFunc(meta) + } + } + } + + return nil +} diff --git a/pkg/server/router/router.go b/pkg/server/router/router.go index 388e792d3..f19a46eb5 100644 --- a/pkg/server/router/router.go +++ b/pkg/server/router/router.go @@ -134,8 +134,6 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string } } - muxer.SortRoutes() - chain := alice.New() chain = chain.Append(func(next http.Handler) (http.Handler, error) { return recovery.New(ctx, next)