From dad76e047831c9b083ea68a24a14a5a5ccec9eff Mon Sep 17 00:00:00 2001 From: Daniel Tomcej Date: Thu, 17 Mar 2022 11:02:08 -0600 Subject: [PATCH] Add muxer for TCP Routers --- .golangci.toml | 3 + .../dynamic-configuration/docker-labels.yml | 2 + .../reference/dynamic-configuration/file.toml | 2 + .../reference/dynamic-configuration/file.yaml | 2 + .../reference/dynamic-configuration/kv-ref.md | 2 + .../marathon-labels.json | 2 + .../traefik.containo.us_ingressroutetcps.yaml | 2 + .../routing/providers/consul-catalog.md | 3 +- docs/content/routing/providers/docker.md | 8 + docs/content/routing/providers/ecs.md | 8 + .../routing/providers/kubernetes-crd.md | 121 +-- docs/content/routing/providers/kv.md | 9 +- docs/content/routing/providers/marathon.md | 8 + docs/content/routing/providers/rancher.md | 8 + docs/content/routing/routers/index.md | 135 ++- integration/fixtures/k8s/01-traefik-crd.yml | 2 + integration/simple_test.go | 2 +- pkg/config/dynamic/tcp_config.go | 1 + pkg/config/label/label_test.go | 24 +- pkg/{rules/rules.go => muxer/http/mux.go} | 124 +-- .../rules_test.go => muxer/http/mux_test.go} | 66 +- pkg/muxer/tcp/mux.go | 328 ++++++++ pkg/muxer/tcp/mux_test.go | 776 ++++++++++++++++++ pkg/provider/acme/provider.go | 7 +- pkg/provider/kubernetes/crd/kubernetes_tcp.go | 1 + .../crd/traefik/v1alpha1/ingressroutetcp.go | 1 + pkg/rules/parser.go | 229 +++--- pkg/rules/parser_test.go | 301 +++++++ pkg/server/router/router.go | 10 +- pkg/server/router/tcp/manager.go | 365 ++++++++ .../tcp/{router_test.go => manager_test.go} | 29 + pkg/server/router/tcp/router.go | 617 ++++++++------ pkg/server/routerfactory.go | 17 +- pkg/server/server_entrypoint_tcp.go | 17 +- pkg/server/server_entrypoint_tcp_http3.go | 4 +- .../server_entrypoint_tcp_http3_test.go | 10 +- pkg/server/server_entrypoint_tcp_test.go | 28 +- .../service/loadbalancer/mirror/mirror.go | 2 +- pkg/tcp/router.go | 286 ------- 39 files changed, 2661 insertions(+), 901 deletions(-) rename pkg/{rules/rules.go => muxer/http/mux.go} (71%) rename pkg/{rules/rules_test.go => muxer/http/mux_test.go} (94%) create mode 100644 pkg/muxer/tcp/mux.go create mode 100644 pkg/muxer/tcp/mux_test.go create mode 100644 pkg/rules/parser_test.go create mode 100644 pkg/server/router/tcp/manager.go rename pkg/server/router/tcp/{router_test.go => manager_test.go} (95%) delete mode 100644 pkg/tcp/router.go diff --git a/.golangci.toml b/.golangci.toml index ae9b717c8..8e3e3d246 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -214,3 +214,6 @@ [[issues.exclude-rules]] path = "(.+)\\.go" text = "struct-tag: unknown option 'inline' in JSON tag" + [[issues.exclude-rules]] + path = "pkg/server/router/tcp/manager.go" + text = "Function 'buildEntryPointHandler' is too long (.+)" diff --git a/docs/content/reference/dynamic-configuration/docker-labels.yml b/docs/content/reference/dynamic-configuration/docker-labels.yml index ee2f33122..a28358558 100644 --- a/docs/content/reference/dynamic-configuration/docker-labels.yml +++ b/docs/content/reference/dynamic-configuration/docker-labels.yml @@ -165,6 +165,7 @@ - "traefik.tcp.routers.tcprouter0.entrypoints=foobar, foobar" - "traefik.tcp.routers.tcprouter0.middlewares=foobar, foobar" - "traefik.tcp.routers.tcprouter0.rule=foobar" +- "traefik.tcp.routers.tcprouter0.priority=42" - "traefik.tcp.routers.tcprouter0.service=foobar" - "traefik.tcp.routers.tcprouter0.tls=true" - "traefik.tcp.routers.tcprouter0.tls.certresolver=foobar" @@ -177,6 +178,7 @@ - "traefik.tcp.routers.tcprouter1.entrypoints=foobar, foobar" - "traefik.tcp.routers.tcprouter1.middlewares=foobar, foobar" - "traefik.tcp.routers.tcprouter1.rule=foobar" +- "traefik.tcp.routers.tcprouter1.priority=42" - "traefik.tcp.routers.tcprouter1.service=foobar" - "traefik.tcp.routers.tcprouter1.tls=true" - "traefik.tcp.routers.tcprouter1.tls.certresolver=foobar" diff --git a/docs/content/reference/dynamic-configuration/file.toml b/docs/content/reference/dynamic-configuration/file.toml index 08c45f0b7..87ce97c27 100644 --- a/docs/content/reference/dynamic-configuration/file.toml +++ b/docs/content/reference/dynamic-configuration/file.toml @@ -326,6 +326,7 @@ middlewares = ["foobar", "foobar"] service = "foobar" rule = "foobar" + priority = 42 [tcp.routers.TCPRouter0.tls] passthrough = true options = "foobar" @@ -343,6 +344,7 @@ middlewares = ["foobar", "foobar"] service = "foobar" rule = "foobar" + priority = 42 [tcp.routers.TCPRouter1.tls] passthrough = true options = "foobar" diff --git a/docs/content/reference/dynamic-configuration/file.yaml b/docs/content/reference/dynamic-configuration/file.yaml index 0fa87fa6c..b5345b4b2 100644 --- a/docs/content/reference/dynamic-configuration/file.yaml +++ b/docs/content/reference/dynamic-configuration/file.yaml @@ -366,6 +366,7 @@ tcp: - foobar service: foobar rule: foobar + priority: 42 tls: passthrough: true options: foobar @@ -388,6 +389,7 @@ tcp: - foobar service: foobar rule: foobar + priority: 42 tls: passthrough: true options: foobar diff --git a/docs/content/reference/dynamic-configuration/kv-ref.md b/docs/content/reference/dynamic-configuration/kv-ref.md index 591052b6b..959980b9f 100644 --- a/docs/content/reference/dynamic-configuration/kv-ref.md +++ b/docs/content/reference/dynamic-configuration/kv-ref.md @@ -237,6 +237,7 @@ | `traefik/tcp/routers/TCPRouter0/entryPoints/1` | `foobar` | | `traefik/tcp/routers/TCPRouter0/middlewares/0` | `foobar` | | `traefik/tcp/routers/TCPRouter0/middlewares/1` | `foobar` | +| `traefik/tcp/routers/TCPRouter0/priority` | `42` | | `traefik/tcp/routers/TCPRouter0/rule` | `foobar` | | `traefik/tcp/routers/TCPRouter0/service` | `foobar` | | `traefik/tcp/routers/TCPRouter0/tls/certResolver` | `foobar` | @@ -252,6 +253,7 @@ | `traefik/tcp/routers/TCPRouter1/entryPoints/1` | `foobar` | | `traefik/tcp/routers/TCPRouter1/middlewares/0` | `foobar` | | `traefik/tcp/routers/TCPRouter1/middlewares/1` | `foobar` | +| `traefik/tcp/routers/TCPRouter1/priority` | `42` | | `traefik/tcp/routers/TCPRouter1/rule` | `foobar` | | `traefik/tcp/routers/TCPRouter1/service` | `foobar` | | `traefik/tcp/routers/TCPRouter1/tls/certResolver` | `foobar` | diff --git a/docs/content/reference/dynamic-configuration/marathon-labels.json b/docs/content/reference/dynamic-configuration/marathon-labels.json index 58128731d..900c50727 100644 --- a/docs/content/reference/dynamic-configuration/marathon-labels.json +++ b/docs/content/reference/dynamic-configuration/marathon-labels.json @@ -163,6 +163,7 @@ "traefik.http.services.service01.loadbalancer.serverstransport": "foobar", "traefik.tcp.routers.tcprouter0.entrypoints": "foobar, foobar", "traefik.tcp.routers.tcprouter0.rule": "foobar", +"traefik.tcp.routers.tcprouter0.priority": "42", "traefik.tcp.routers.tcprouter0.service": "foobar", "traefik.tcp.routers.tcprouter0.tls": "true", "traefik.tcp.routers.tcprouter0.tls.certresolver": "foobar", @@ -174,6 +175,7 @@ "traefik.tcp.routers.tcprouter0.tls.passthrough": "true", "traefik.tcp.routers.tcprouter1.entrypoints": "foobar, foobar", "traefik.tcp.routers.tcprouter1.rule": "foobar", +"traefik.tcp.routers.tcprouter1.priority": "42", "traefik.tcp.routers.tcprouter1.service": "foobar", "traefik.tcp.routers.tcprouter1.tls": "true", "traefik.tcp.routers.tcprouter1.tls.certresolver": "foobar", diff --git a/docs/content/reference/dynamic-configuration/traefik.containo.us_ingressroutetcps.yaml b/docs/content/reference/dynamic-configuration/traefik.containo.us_ingressroutetcps.yaml index 6432396dd..e20d612a7 100644 --- a/docs/content/reference/dynamic-configuration/traefik.containo.us_ingressroutetcps.yaml +++ b/docs/content/reference/dynamic-configuration/traefik.containo.us_ingressroutetcps.yaml @@ -62,6 +62,8 @@ spec: - name type: object type: array + priority: + type: integer services: items: description: ServiceTCP defines an upstream to proxy traffic. diff --git a/docs/content/routing/providers/consul-catalog.md b/docs/content/routing/providers/consul-catalog.md index 1431ad71c..1d452370a 100644 --- a/docs/content/routing/providers/consul-catalog.md +++ b/docs/content/routing/providers/consul-catalog.md @@ -99,7 +99,8 @@ For example, to change the rule, you could add the tag ```traefik.http.routers.m ``` ??? info "`traefik.http.routers..priority`" - + + See [priority](../routers/index.md#priority) for more information. ```yaml traefik.http.routers.myrouter.priority=42 diff --git a/docs/content/routing/providers/docker.md b/docs/content/routing/providers/docker.md index 923c21237..77cdce49a 100644 --- a/docs/content/routing/providers/docker.md +++ b/docs/content/routing/providers/docker.md @@ -538,6 +538,14 @@ You can declare TCP Routers and/or Services using labels. - "traefik.tcp.routers.mytcprouter.tls.passthrough=true" ``` +??? info "`traefik.tcp.routers..priority`" + + See [priority](../routers/index.md#priority_1) for more information. + + ```yaml + - "traefik.tcp.routers.myrouter.priority=42" + ``` + #### TCP Services ??? info "`traefik.tcp.services..loadbalancer.server.port`" diff --git a/docs/content/routing/providers/ecs.md b/docs/content/routing/providers/ecs.md index cfcb1af8c..5e60fd239 100644 --- a/docs/content/routing/providers/ecs.md +++ b/docs/content/routing/providers/ecs.md @@ -379,6 +379,14 @@ You can declare TCP Routers and/or Services using labels. traefik.tcp.routers.mytcprouter.tls.passthrough=true ``` +??? info "`traefik.tcp.routers..priority`" + + See [priority](../routers/index.md#priority_1) for more information. + + ```yaml + traefik.tcp.routers.myrouter.priority=42 + ``` + #### TCP Services ??? info "`traefik.tcp.services..loadbalancer.server.port`" diff --git a/docs/content/routing/providers/kubernetes-crd.md b/docs/content/routing/providers/kubernetes-crd.md index 0eb1552b7..a23cd86d6 100644 --- a/docs/content/routing/providers/kubernetes-crd.md +++ b/docs/content/routing/providers/kubernetes-crd.md @@ -357,27 +357,27 @@ Register the `IngressRoute` [kind](../../reference/dynamic-configuration/kuberne - b.example.net ``` -| Ref | Attribute | Purpose | -|------|--------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [1] | `entryPoints` | List of [entry points](../routers/index.md#entrypoints) names | -| [2] | `routes` | List of routes | -| [3] | `routes[n].match` | Defines the [rule](../routers/index.md#rule) corresponding to an underlying router. | -| [4] | `routes[n].priority` | [Disambiguate](../routers/index.md#priority) rules of the same length, for route matching | -| [5] | `routes[n].middlewares` | List of reference to [Middleware](#kind-middleware) | -| [6] | `middlewares[n].name` | Defines the [Middleware](#kind-middleware) name | -| [7] | `middlewares[n].namespace` | Defines the [Middleware](#kind-middleware) namespace | -| [8] | `routes[n].services` | List of any combination of [TraefikService](#kind-traefikservice) and reference to a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/) (See below for `ExternalName Service` setup) | -| [9] | `services[n].port` | Defines the port of a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/). This can be a reference to a named port. | +| Ref | Attribute | Purpose | +|------|--------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [1] | `entryPoints` | List of [entry points](../routers/index.md#entrypoints) names | +| [2] | `routes` | List of routes | +| [3] | `routes[n].match` | Defines the [rule](../routers/index.md#rule) corresponding to an underlying router. | +| [4] | `routes[n].priority` | Defines the [priority](../routers/index.md#priority) to disambiguate rules of the same length, for route matching | +| [5] | `routes[n].middlewares` | List of reference to [Middleware](#kind-middleware) | +| [6] | `middlewares[n].name` | Defines the [Middleware](#kind-middleware) name | +| [7] | `middlewares[n].namespace` | Defines the [Middleware](#kind-middleware) namespace | +| [8] | `routes[n].services` | List of any combination of [TraefikService](#kind-traefikservice) and reference to a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/) (See below for `ExternalName Service` setup) | +| [9] | `services[n].port` | Defines the port of a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/). This can be a reference to a named port. | | [10] | `services[n].serversTransport` | Defines the reference to a [ServersTransport](#kind-serverstransport). The ServersTransport namespace is assumed to be the [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/) namespace (see [ServersTransport reference](#serverstransport-reference)). | -| [11] | `tls` | Defines [TLS](../routers/index.md#tls) certificate configuration | -| [12] | `tls.secretName` | Defines the [secret](https://kubernetes.io/docs/concepts/configuration/secret/) name used to store the certificate (in the `IngressRoute` namespace) | -| [13] | `tls.options` | Defines the reference to a [TLSOption](#kind-tlsoption) | -| [14] | `options.name` | Defines the [TLSOption](#kind-tlsoption) name | -| [15] | `options.namespace` | Defines the [TLSOption](#kind-tlsoption) namespace | -| [16] | `tls.certResolver` | Defines the reference to a [CertResolver](../routers/index.md#certresolver) | -| [17] | `tls.domains` | List of [domains](../routers/index.md#domains) | -| [18] | `domains[n].main` | Defines the main domain name | -| [19] | `domains[n].sans` | List of SANs (alternative domains) | +| [11] | `tls` | Defines [TLS](../routers/index.md#tls) certificate configuration | +| [12] | `tls.secretName` | Defines the [secret](https://kubernetes.io/docs/concepts/configuration/secret/) name used to store the certificate (in the `IngressRoute` namespace) | +| [13] | `tls.options` | Defines the reference to a [TLSOption](#kind-tlsoption) | +| [14] | `options.name` | Defines the [TLSOption](#kind-tlsoption) name | +| [15] | `options.namespace` | Defines the [TLSOption](#kind-tlsoption) namespace | +| [16] | `tls.certResolver` | Defines the reference to a [CertResolver](../routers/index.md#certresolver) | +| [17] | `tls.domains` | List of [domains](../routers/index.md#domains) | +| [18] | `domains[n].main` | Defines the main domain name | +| [19] | `domains[n].sans` | List of SANs (alternative domains) | ??? example "Declaring an IngressRoute" @@ -1088,54 +1088,56 @@ Register the `IngressRouteTCP` [kind](../../reference/dynamic-configuration/kube - footcp routes: # [2] - match: HostSNI(`*`) # [3] + priority: 10 # [4] middlewares: - - name: middleware1 # [4] - namespace: default # [5] - services: # [6] - - name: foo # [7] - port: 8080 # [8] - weight: 10 # [9] - terminationDelay: 400 # [10] - proxyProtocol: # [11] - version: 1 # [12] - tls: # [13] - secretName: supersecret # [14] - options: # [15] - name: opt # [16] - namespace: default # [17] - certResolver: foo # [18] - domains: # [19] - - main: example.net # [20] - sans: # [21] + - name: middleware1 # [5] + namespace: default # [6] + services: # [7] + - name: foo # [8] + port: 8080 # [9] + weight: 10 # [10] + terminationDelay: 400 # [11] + proxyProtocol: # [12] + version: 1 # [13] + tls: # [14] + secretName: supersecret # [15] + options: # [16] + name: opt # [17] + namespace: default # [18] + certResolver: foo # [19] + domains: # [20] + - main: example.net # [21] + sans: # [22] - a.example.net - b.example.net - passthrough: false # [22] + passthrough: false # [23] ``` | Ref | Attribute | Purpose | |------|--------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [1] | `entryPoints` | List of [entrypoints](../routers/index.md#entrypoints_1) names | | [2] | `routes` | List of routes | -| [3] | `routes[n].match` | Defines the [rule](../routers/index.md#rule_1) corresponding to an underlying router | -| [4] | `middlewares[n].name` | Defines the [MiddlewareTCP](#kind-middlewaretcp) name | -| [5] | `middlewares[n].namespace` | Defines the [MiddlewareTCP](#kind-middlewaretcp) namespace | -| [6] | `routes[n].services` | List of [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/) definitions (See below for `ExternalName Service` setup) | -| [7] | `services[n].name` | Defines the name of a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/) | -| [8] | `services[n].port` | Defines the port of a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/). This can be a reference to a named port. | -| [9] | `services[n].weight` | Defines the weight to apply to the server load balancing | -| [10] | `services[n].terminationDelay` | corresponds to the deadline that the proxy sets, after one of its connected peers indicates it has closed the writing capability of its connection, to close the reading capability as well, hence fully terminating the connection. It is a duration in milliseconds, defaulting to 100. A negative value means an infinite deadline (i.e. the reading capability is never closed). | -| [11] | `proxyProtocol` | Defines the [PROXY protocol](../services/index.md#proxy-protocol) configuration | -| [12] | `version` | Defines the [PROXY protocol](../services/index.md#proxy-protocol) version | -| [13] | `tls` | Defines [TLS](../routers/index.md#tls_1) certificate configuration | -| [14] | `tls.secretName` | Defines the [secret](https://kubernetes.io/docs/concepts/configuration/secret/) name used to store the certificate (in the `IngressRoute` namespace) | -| [15] | `tls.options` | Defines the reference to a [TLSOption](#kind-tlsoption) | -| [16] | `options.name` | Defines the [TLSOption](#kind-tlsoption) name | -| [17] | `options.namespace` | Defines the [TLSOption](#kind-tlsoption) namespace | -| [18] | `tls.certResolver` | Defines the reference to a [CertResolver](../routers/index.md#certresolver_1) | -| [19] | `tls.domains` | List of [domains](../routers/index.md#domains_1) | -| [20] | `domains[n].main` | Defines the main domain name | -| [21] | `domains[n].sans` | List of SANs (alternative domains) | -| [22] | `tls.passthrough` | If `true`, delegates the TLS termination to the backend | +| [3] | `routes[n].match` | Defines the [rule](../routers/index.md#rule_1) of the underlying router | +| [4] | `routes[n].priority` | Defines the [priority](../routers/index.md#priority_1) to disambiguate rules of the same length, for route matching | +| [5] | `middlewares[n].name` | Defines the [MiddlewareTCP](#kind-middlewaretcp) name | +| [6] | `middlewares[n].namespace` | Defines the [MiddlewareTCP](#kind-middlewaretcp) namespace | +| [7] | `routes[n].services` | List of [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/) definitions (See below for `ExternalName Service` setup) | +| [8] | `services[n].name` | Defines the name of a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/) | +| [9] | `services[n].port` | Defines the port of a [Kubernetes service](https://kubernetes.io/docs/concepts/services-networking/service/). This can be a reference to a named port. | +| [10] | `services[n].weight` | Defines the weight to apply to the server load balancing | +| [11] | `services[n].terminationDelay` | corresponds to the deadline that the proxy sets, after one of its connected peers indicates it has closed the writing capability of its connection, to close the reading capability as well, hence fully terminating the connection. It is a duration in milliseconds, defaulting to 100. A negative value means an infinite deadline (i.e. the reading capability is never closed). | +| [12] | `proxyProtocol` | Defines the [PROXY protocol](../services/index.md#proxy-protocol) configuration | +| [13] | `version` | Defines the [PROXY protocol](../services/index.md#proxy-protocol) version | +| [14] | `tls` | Defines [TLS](../routers/index.md#tls_1) certificate configuration | +| [15] | `tls.secretName` | Defines the [secret](https://kubernetes.io/docs/concepts/configuration/secret/) name used to store the certificate (in the `IngressRoute` namespace) | +| [16] | `tls.options` | Defines the reference to a [TLSOption](#kind-tlsoption) | +| [17] | `options.name` | Defines the [TLSOption](#kind-tlsoption) name | +| [18] | `options.namespace` | Defines the [TLSOption](#kind-tlsoption) namespace | +| [19] | `tls.certResolver` | Defines the reference to a [CertResolver](../routers/index.md#certresolver_1) | +| [20] | `tls.domains` | List of [domains](../routers/index.md#domains_1) | +| [21] | `domains[n].main` | Defines the main domain name | +| [22] | `domains[n].sans` | List of SANs (alternative domains) | +| [23] | `tls.passthrough` | If `true`, delegates the TLS termination to the backend | ??? example "Declaring an IngressRouteTCP" @@ -1151,6 +1153,7 @@ Register the `IngressRouteTCP` [kind](../../reference/dynamic-configuration/kube routes: # Match is the rule corresponding to an underlying router. - match: HostSNI(`*`) + priority: 10 services: - name: foo port: 8080 diff --git a/docs/content/routing/providers/kv.md b/docs/content/routing/providers/kv.md index 34ca6d38f..5fb1c0c6c 100644 --- a/docs/content/routing/providers/kv.md +++ b/docs/content/routing/providers/kv.md @@ -366,7 +366,6 @@ You can declare TCP Routers and/or Services using KV. | Key (Path) | Value | |-----------------------------------------------|----------| | `traefik/tcp/routers/mytcprouter/tls/options` | `foobar` | - ??? info "`traefik/tcp/routers//tls/passthrough`" @@ -376,6 +375,14 @@ You can declare TCP Routers and/or Services using KV. |---------------------------------------------------|--------| | `traefik/tcp/routers/mytcprouter/tls/passthrough` | `true` | +??? info "`traefik/tcp/routers//priority`" + + See [priority](../routers/index.md#priority_1) for more information. + + | Key (Path) | Value | + |------------------------------------------|-------| + | `traefik/tcp/routers/myrouter/priority` | `42` | + #### TCP Services ??? info "`traefik/tcp/services//loadbalancer/servers//url`" diff --git a/docs/content/routing/providers/marathon.md b/docs/content/routing/providers/marathon.md index 5591981fb..7d5d17371 100644 --- a/docs/content/routing/providers/marathon.md +++ b/docs/content/routing/providers/marathon.md @@ -412,6 +412,14 @@ You can declare TCP Routers and/or Services using labels. "traefik.tcp.routers.mytcprouter.tls.passthrough": "true" ``` +??? info "`traefik.tcp.routers..priority`" + + See [priority](../routers/index.md#priority_1) for more information. + + ```json + "traefik.tcp.routers.myrouter.priority": "42" + ``` + #### TCP Services ??? info "`traefik.tcp.services..loadbalancer.server.port`" diff --git a/docs/content/routing/providers/rancher.md b/docs/content/routing/providers/rancher.md index 1a5aa0444..ed751b6ea 100644 --- a/docs/content/routing/providers/rancher.md +++ b/docs/content/routing/providers/rancher.md @@ -415,6 +415,14 @@ You can declare TCP Routers and/or Services using labels. - "traefik.tcp.routers.mytcprouter.tls.passthrough=true" ``` +??? info "`traefik.tcp.routers..priority`" + + See [priority](../routers/index.md#priority_1) for more information. + + ```yaml + - "traefik.tcp.routers.myrouter.priority=42" + ``` + #### TCP Services ??? info "`traefik.tcp.services..loadbalancer.server.port`" diff --git a/docs/content/routing/routers/index.md b/docs/content/routing/routers/index.md index 4abc03a06..7b52fbc97 100644 --- a/docs/content/routing/routers/index.md +++ b/docs/content/routing/routers/index.md @@ -212,7 +212,7 @@ If the rule is verified, the router becomes active, calls middlewares, and then ??? tip "Backticks or Quotes?" To set the value of a rule, use [backticks](https://en.wiktionary.org/wiki/backtick) ``` ` ``` or escaped double-quotes `\"`. - Single quotes `'` are not accepted as values are [Golang's String Literals](https://golang.org/ref/spec#String_literals). + Single quotes `'` are not accepted since the values are [Golang's String Literals](https://golang.org/ref/spec#String_literals). !!! example "Host is example.com" @@ -257,11 +257,12 @@ The table below lists all the available matchers: !!! info "Combining Matchers Using Operators and Parenthesis" - You can combine multiple matchers using the AND (`&&`) and OR (`||`) operators. You can also use parenthesis. + The usual AND (`&&`) and OR (`||`) logical operators can be used, with the expected precedence rules, + as well as parentheses. -!!! info "Invert a matcher" +!!! info "Inverting a matcher" - You can invert a matcher by using the `!` operator. + One can invert a matcher by using the `!` operator. !!! important "Rule, Middleware, and Services" @@ -795,9 +796,33 @@ If you want to limit the router scope to a set of entry points, set the entry po ### Rule -| Rule | Description | -|--------------------------------|-------------------------------------------------------------------------| -| ```HostSNI(`domain-1`, ...)``` | Check if the Server Name Indication corresponds to the given `domains`. | +Rules are a set of matchers configured with values, that determine if a particular request matches specific criteria. +If the rule is verified, the router becomes active, calls middlewares, and then forwards the request to the service. + +??? tip "Backticks or Quotes?" + + To set the value of a rule, use [backticks](https://en.wiktionary.org/wiki/backtick) ``` ` ``` or escaped double-quotes `\"`. + + Single quotes `'` are not accepted since the values are [Golang's String Literals](https://golang.org/ref/spec#String_literals). + +!!! example "HostSNI is example.com" + + ```toml + rule = "HostSNI(`example.com`)" + ``` + +!!! example "HostSNI is example.com OR HostSNI is example.org AND ClientIP is 0.0.0.0" + + ```toml + rule = "HostSNI(`example.com`) || (HostSNI(`example.org`) && ClientIP(`0.0.0.0`))" + ``` + +The table below lists all the available matchers: + +| Rule | Description | +|---------------------------------------------|-----------------------------------------------------------------------------------------------------------| +| ```HostSNI(`domain-1`, ...)``` | Check if the Server Name Indication corresponds to the given `domains`. | +| ```ClientIP(`10.0.0.0/16`, `::1`)``` | Check if the request client IP is one of the given IP/CIDR. It accepts IPv4, IPv6 and CIDR formats. | !!! important "Non-ASCII Domain Names" @@ -808,7 +833,101 @@ If you want to limit the router scope to a set of entry points, set the entry po It is important to note that the Server Name Indication is an extension of the TLS protocol. Hence, only TLS routers will be able to specify a domain name with that rule. - However, non-TLS routers will have to explicitly use that rule with `*` (every domain) to state that every non-TLS request will be handled by the router. + However, there is one special use case for HostSNI with non-TLS routers: + when one wants a non-TLS router that matches all (non-TLS) requests, + one should use the specific `HostSNI(*)` syntax. + +!!! info "Combining Matchers Using Operators and Parenthesis" + + The usual AND (`&&`) and OR (`||`) logical operators can be used, with the expected precedence rules, + as well as parentheses. + +!!! info "Inverting a matcher" + + One can invert a matcher by using the `!` operator. + +!!! important "Rule, Middleware, and Services" + + The rule is evaluated "before" any middleware has the opportunity to work, and "before" the request is forwarded to the service. + +### Priority + +To avoid path overlap, routes are sorted, by default, in descending order using rules length. +The priority is directly equal to the length of the rule, and so the longest length has the highest priority. + +A value of `0` for the priority is ignored: `priority = 0` means that the default rules length sorting is used. + +??? info "How default priorities are computed" + + ```yaml tab="File (YAML)" + ## Dynamic configuration + tcp: + routers: + Router-1: + rule: "ClientIP(`192.168.0.12`)" + # ... + Router-2: + rule: "ClientIP(`192.168.0.0/24`)" + # ... + ``` + + ```toml tab="File (TOML)" + ## Dynamic configuration + [tcp.routers] + [tcp.routers.Router-1] + rule = "ClientIP(`192.168.0.12`)" + # ... + [tcp.routers.Router-2] + rule = "ClientIP(`192.168.0.0/24`)" + # ... + ``` + + The table below shows that `Router-2` has a higher computed priority than `Router-1`. + + | Name | Rule | Priority | + |----------|-------------------------------------------------------------|----------| + | Router-1 | ```ClientIP(`192.168.0.12`)``` | 24 | + | Router-2 | ```ClientIP(`192.168.0.0/24`)``` | 26 | + + Which means that requests from `192.168.0.12` would go to Router-2 even though Router-1 is intended to specifically handle them. + To achieve this intention, a priority (higher than 26) should be set on Router-1. + +??? example "Setting priorities -- using the [File Provider](../../providers/file.md)" + + ```yaml tab="File (YAML)" + ## Dynamic configuration + tcp: + routers: + Router-1: + rule: "ClientIP(`192.168.0.12`)" + entryPoints: + - "web" + service: service-1 + priority: 2 + Router-2: + rule: "ClientIP(`192.168.0.0/24`)" + entryPoints: + - "web" + priority: 1 + service: service-2 + ``` + + ```toml tab="File (TOML)" + ## Dynamic configuration + [tcp.routers] + [tcp.routers.Router-1] + rule = "ClientIP(`192.168.0.12`)" + entryPoints = ["web"] + service = "service-1" + priority = 2 + [tcp.routers.Router-2] + rule = "ClientIP(`192.168.0.0/24`)" + entryPoints = ["web"] + priority = 1 + service = "service-2" + ``` + + In this configuration, the priority is configured so that `Router-1` will handle requests from `192.168.0.12`. ### Middlewares diff --git a/integration/fixtures/k8s/01-traefik-crd.yml b/integration/fixtures/k8s/01-traefik-crd.yml index 0daee3add..51c471a3d 100644 --- a/integration/fixtures/k8s/01-traefik-crd.yml +++ b/integration/fixtures/k8s/01-traefik-crd.yml @@ -260,6 +260,8 @@ spec: - name type: object type: array + priority: + type: integer services: items: description: ServiceTCP defines an upstream to proxy traffic. diff --git a/integration/simple_test.go b/integration/simple_test.go index 5f36eabf6..f16efb677 100644 --- a/integration/simple_test.go +++ b/integration/simple_test.go @@ -663,7 +663,7 @@ func (s *SimpleSuite) TestTCPRouterConfigErrors(c *check.C) { c.Assert(err, checker.IsNil) // router4 has an unsupported Rule - err = try.GetRequest("http://127.0.0.1:8080/api/tcp/routers/router4@file", 1000*time.Millisecond, try.BodyContains("unknown rule Host(`mydomain.com`)")) + err = try.GetRequest("http://127.0.0.1:8080/api/tcp/routers/router4@file", 1000*time.Millisecond, try.BodyContains("invalid rule: \\\"Host(`mydomain.com`)\\\"")) c.Assert(err, checker.IsNil) } diff --git a/pkg/config/dynamic/tcp_config.go b/pkg/config/dynamic/tcp_config.go index 270f490d1..71042dc0b 100644 --- a/pkg/config/dynamic/tcp_config.go +++ b/pkg/config/dynamic/tcp_config.go @@ -52,6 +52,7 @@ type TCPRouter struct { Middlewares []string `json:"middlewares,omitempty" toml:"middlewares,omitempty" yaml:"middlewares,omitempty" export:"true"` Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty" export:"true"` Rule string `json:"rule,omitempty" toml:"rule,omitempty" yaml:"rule,omitempty"` + Priority int `json:"priority,omitempty" toml:"priority,omitempty,omitzero" yaml:"priority,omitempty" export:"true"` TLS *RouterTCPTLSConfig `json:"tls,omitempty" toml:"tls,omitempty" yaml:"tls,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"` } diff --git a/pkg/config/label/label_test.go b/pkg/config/label/label_test.go index 4a2126f85..612120fb0 100644 --- a/pkg/config/label/label_test.go +++ b/pkg/config/label/label_test.go @@ -176,11 +176,13 @@ func TestDecodeConfiguration(t *testing.T) { "traefik.tcp.middlewares.Middleware0.ipwhitelist.sourcerange": "foobar, fiibar", "traefik.tcp.middlewares.Middleware2.inflightconn.amount": "42", "traefik.tcp.routers.Router0.rule": "foobar", + "traefik.tcp.routers.Router0.priority": "42", "traefik.tcp.routers.Router0.entrypoints": "foobar, fiibar", "traefik.tcp.routers.Router0.service": "foobar", "traefik.tcp.routers.Router0.tls.passthrough": "false", "traefik.tcp.routers.Router0.tls.options": "foo", "traefik.tcp.routers.Router1.rule": "foobar", + "traefik.tcp.routers.Router1.priority": "42", "traefik.tcp.routers.Router1.entrypoints": "foobar, fiibar", "traefik.tcp.routers.Router1.service": "foobar", "traefik.tcp.routers.Router1.tls.options": "foo", @@ -211,8 +213,9 @@ func TestDecodeConfiguration(t *testing.T) { "foobar", "fiibar", }, - Service: "foobar", - Rule: "foobar", + Service: "foobar", + Rule: "foobar", + Priority: 42, TLS: &dynamic.RouterTCPTLSConfig{ Passthrough: false, Options: "foo", @@ -223,8 +226,9 @@ func TestDecodeConfiguration(t *testing.T) { "foobar", "fiibar", }, - Service: "foobar", - Rule: "foobar", + Service: "foobar", + Rule: "foobar", + Priority: 42, TLS: &dynamic.RouterTCPTLSConfig{ Passthrough: false, Options: "foo", @@ -699,8 +703,9 @@ func TestEncodeConfiguration(t *testing.T) { "foobar", "fiibar", }, - Service: "foobar", - Rule: "foobar", + Service: "foobar", + Rule: "foobar", + Priority: 42, TLS: &dynamic.RouterTCPTLSConfig{ Passthrough: false, Options: "foo", @@ -711,8 +716,9 @@ func TestEncodeConfiguration(t *testing.T) { "foobar", "fiibar", }, - Service: "foobar", - Rule: "foobar", + Service: "foobar", + Rule: "foobar", + Priority: 42, TLS: &dynamic.RouterTCPTLSConfig{ Passthrough: false, Options: "foo", @@ -1333,11 +1339,13 @@ func TestEncodeConfiguration(t *testing.T) { "traefik.TCP.Middlewares.Middleware0.IPWhiteList.SourceRange": "foobar, fiibar", "traefik.TCP.Middlewares.Middleware2.InFlightConn.Amount": "42", "traefik.TCP.Routers.Router0.Rule": "foobar", + "traefik.TCP.Routers.Router0.Priority": "42", "traefik.TCP.Routers.Router0.EntryPoints": "foobar, fiibar", "traefik.TCP.Routers.Router0.Service": "foobar", "traefik.TCP.Routers.Router0.TLS.Passthrough": "false", "traefik.TCP.Routers.Router0.TLS.Options": "foo", "traefik.TCP.Routers.Router1.Rule": "foobar", + "traefik.TCP.Routers.Router1.Priority": "42", "traefik.TCP.Routers.Router1.EntryPoints": "foobar, fiibar", "traefik.TCP.Routers.Router1.Service": "foobar", "traefik.TCP.Routers.Router1.TLS.Passthrough": "false", diff --git a/pkg/rules/rules.go b/pkg/muxer/http/mux.go similarity index 71% rename from pkg/rules/rules.go rename to pkg/muxer/http/mux.go index 3f709eae7..6c57d8db8 100644 --- a/pkg/rules/rules.go +++ b/pkg/muxer/http/mux.go @@ -1,4 +1,4 @@ -package rules +package http import ( "fmt" @@ -10,11 +10,14 @@ import ( "github.com/traefik/traefik/v2/pkg/ip" "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator" + "github.com/traefik/traefik/v2/pkg/rules" "github.com/vulcand/predicate" ) -var funcs = map[string]func(*mux.Route, ...string) error{ - "Host": host, +const hostMatcher = "Host" + +var httpFuncs = map[string]func(*mux.Route, ...string) error{ + hostMatcher: host, "HostHeader": host, "HostRegexp": hostRegexp, "ClientIP": clientIP, @@ -26,33 +29,38 @@ var funcs = map[string]func(*mux.Route, ...string) error{ "Query": query, } -// Router handle routing with rules. -type Router struct { +// Muxer handles routing with rules. +type Muxer struct { *mux.Router parser predicate.Parser } -// NewRouter returns a new router instance. -func NewRouter() (*Router, error) { - parser, err := newParser() +// NewMuxer returns a new muxer instance. +func NewMuxer() (*Muxer, error) { + var matchers []string + for matcher := range httpFuncs { + matchers = append(matchers, matcher) + } + + parser, err := rules.NewParser(matchers) if err != nil { return nil, err } - return &Router{ + return &Muxer{ Router: mux.NewRouter().SkipClean(true), parser: parser, }, nil } // AddRoute add a new route to the router. -func (r *Router) AddRoute(rule string, priority int, handler http.Handler) error { +func (r *Muxer) AddRoute(rule string, priority int, handler http.Handler) error { parse, err := r.parser.Parse(rule) if err != nil { return fmt.Errorf("error while parsing rule %s: %w", rule, err) } - buildTree, ok := parse.(treeBuilder) + buildTree, ok := parse.(rules.TreeBuilder) if !ok { return fmt.Errorf("error while parsing rule %s", rule) } @@ -72,23 +80,40 @@ func (r *Router) AddRoute(rule string, priority int, handler http.Handler) error return nil } -type tree struct { - matcher string - not bool - value []string - ruleLeft *tree - ruleRight *tree +// ParseDomains extract domains from rule. +func ParseDomains(rule string) ([]string, error) { + var matchers []string + for matcher := range httpFuncs { + matchers = append(matchers, matcher) + } + + parser, err := rules.NewParser(matchers) + if err != nil { + return nil, err + } + + parse, err := parser.Parse(rule) + if err != nil { + return nil, err + } + + buildTree, ok := parse.(rules.TreeBuilder) + if !ok { + return nil, fmt.Errorf("error while parsing rule %s", rule) + } + + return buildTree().ParseMatchers([]string{hostMatcher}), nil } func path(route *mux.Route, paths ...string) error { rt := route.Subrouter() for _, path := range paths { - tmpRt := rt.Path(path) - if tmpRt.GetError() != nil { - return tmpRt.GetError() + if err := rt.Path(path).GetError(); err != nil { + return err } } + return nil } @@ -96,11 +121,11 @@ func pathPrefix(route *mux.Route, paths ...string) error { rt := route.Subrouter() for _, path := range paths { - tmpRt := rt.PathPrefix(path) - if tmpRt.GetError() != nil { - return tmpRt.GetError() + if err := rt.PathPrefix(path).GetError(); err != nil { + return err } } + return nil } @@ -220,33 +245,34 @@ func query(route *mux.Route, query ...string) error { return route.GetError() } -func addRuleOnRouter(router *mux.Router, rule *tree) error { - switch rule.matcher { +func addRuleOnRouter(router *mux.Router, rule *rules.Tree) error { + switch rule.Matcher { case "and": route := router.NewRoute() - err := addRuleOnRoute(route, rule.ruleLeft) + err := addRuleOnRoute(route, rule.RuleLeft) if err != nil { return err } - return addRuleOnRoute(route, rule.ruleRight) + return addRuleOnRoute(route, rule.RuleRight) case "or": - err := addRuleOnRouter(router, rule.ruleLeft) + err := addRuleOnRouter(router, rule.RuleLeft) if err != nil { return err } - return addRuleOnRouter(router, rule.ruleRight) + return addRuleOnRouter(router, rule.RuleRight) default: - err := checkRule(rule) + err := rules.CheckRule(rule) if err != nil { return err } - if rule.not { - return not(funcs[rule.matcher])(router.NewRoute(), rule.value...) + if rule.Not { + return not(httpFuncs[rule.Matcher])(router.NewRoute(), rule.Value...) } - return funcs[rule.matcher](router.NewRoute(), rule.value...) + + return httpFuncs[rule.Matcher](router.NewRoute(), rule.Value...) } } @@ -264,48 +290,36 @@ func not(m func(*mux.Route, ...string) error) func(*mux.Route, ...string) error } } -func addRuleOnRoute(route *mux.Route, rule *tree) error { - switch rule.matcher { +func addRuleOnRoute(route *mux.Route, rule *rules.Tree) error { + switch rule.Matcher { case "and": - err := addRuleOnRoute(route, rule.ruleLeft) + err := addRuleOnRoute(route, rule.RuleLeft) if err != nil { return err } - return addRuleOnRoute(route, rule.ruleRight) + return addRuleOnRoute(route, rule.RuleRight) case "or": subRouter := route.Subrouter() - err := addRuleOnRouter(subRouter, rule.ruleLeft) + err := addRuleOnRouter(subRouter, rule.RuleLeft) if err != nil { return err } - return addRuleOnRouter(subRouter, rule.ruleRight) + return addRuleOnRouter(subRouter, rule.RuleRight) default: - err := checkRule(rule) + err := rules.CheckRule(rule) if err != nil { return err } - if rule.not { - return not(funcs[rule.matcher])(route, rule.value...) + if rule.Not { + return not(httpFuncs[rule.Matcher])(route, rule.Value...) } - return funcs[rule.matcher](route, rule.value...) - } -} -func checkRule(rule *tree) error { - if len(rule.value) == 0 { - return fmt.Errorf("no args for matcher %s", rule.matcher) + return httpFuncs[rule.Matcher](route, rule.Value...) } - - for _, v := range rule.value { - if len(v) == 0 { - return fmt.Errorf("empty args for matcher %s, %v", rule.matcher, rule.value) - } - } - return nil } // IsASCII checks if the given string contains only ASCII characters. diff --git a/pkg/rules/rules_test.go b/pkg/muxer/http/mux_test.go similarity index 94% rename from pkg/rules/rules_test.go rename to pkg/muxer/http/mux_test.go index ea468b5d1..930364ba6 100644 --- a/pkg/rules/rules_test.go +++ b/pkg/muxer/http/mux_test.go @@ -1,4 +1,4 @@ -package rules +package http import ( "net/http" @@ -635,10 +635,10 @@ func Test_addRoute(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) - router, err := NewRouter() + muxer, err := NewMuxer() require.NoError(t, err) - err = router.AddRoute(test.rule, 0, handler) + err = muxer.AddRoute(test.rule, 0, handler) if test.expectedError { require.Error(t, err) } else { @@ -659,7 +659,7 @@ func Test_addRoute(t *testing.T) { for key, value := range test.headers { req.Header.Set(key, value) } - reqHost.ServeHTTP(w, req, router.ServeHTTP) + reqHost.ServeHTTP(w, req, muxer.ServeHTTP) results[calledURL] = w.Code } assert.Equal(t, test.expected, results) @@ -787,7 +787,7 @@ func Test_addRoutePriority(t *testing.T) { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() - router, err := NewRouter() + muxer, err := NewMuxer() require.NoError(t, err) for _, route := range test.cases { @@ -796,16 +796,16 @@ func Test_addRoutePriority(t *testing.T) { w.Header().Set("X-From", route.xFrom) }) - err := router.AddRoute(route.rule, route.priority, handler) + err := muxer.AddRoute(route.rule, route.priority, handler) require.NoError(t, err, route.rule) } - router.SortRoutes() + muxer.SortRoutes() w := httptest.NewRecorder() req := testhelpers.MustNewRequest(http.MethodGet, test.path, nil) - router.ServeHTTP(w, req) + muxer.ServeHTTP(w, req) assert.Equal(t, test.expected, w.Header().Get("X-From")) }) @@ -900,44 +900,42 @@ func TestParseDomains(t *testing.T) { errorExpected bool }{ { - description: "Many host rules", - expression: "Host(`foo.bar`,`test.bar`)", - domain: []string{"foo.bar", "test.bar"}, - errorExpected: false, + description: "Unknown rule", + expression: "Foobar(`foo.bar`,`test.bar`)", + errorExpected: true, }, { - description: "Many host rules upper", - expression: "HOST(`foo.bar`,`test.bar`)", - domain: []string{"foo.bar", "test.bar"}, - errorExpected: false, + description: "Several host rules", + expression: "Host(`foo.bar`,`test.bar`)", + domain: []string{"foo.bar", "test.bar"}, }, { - description: "Many host rules lower", - expression: "host(`foo.bar`,`test.bar`)", - domain: []string{"foo.bar", "test.bar"}, - errorExpected: false, + description: "Several host rules upper", + expression: "HOST(`foo.bar`,`test.bar`)", + domain: []string{"foo.bar", "test.bar"}, }, { - description: "No host rule", - expression: "Path(`/test`)", - errorExpected: false, + description: "Several host rules lower", + expression: "host(`foo.bar`,`test.bar`)", + domain: []string{"foo.bar", "test.bar"}, }, { - description: "Host rule and another rule", - expression: "Host(`foo.bar`) && Path(`/test`)", - domain: []string{"foo.bar"}, - errorExpected: false, + description: "No host rule", + expression: "Path(`/test`)", }, { - description: "Host rule to trim and another rule", - expression: "Host(`Foo.Bar`) && Path(`/test`)", - domain: []string{"foo.bar"}, - errorExpected: false, + description: "Host rule and another rule", + expression: "Host(`foo.bar`) && Path(`/test`)", + domain: []string{"foo.bar"}, }, { - description: "Host rule with no domain", - expression: "Host() && Path(`/test`)", - errorExpected: false, + description: "Host rule to trim and another rule", + expression: "Host(`Foo.Bar`) && Path(`/test`)", + domain: []string{"foo.bar"}, + }, + { + description: "Host rule with no domain", + expression: "Host() && Path(`/test`)", }, } diff --git a/pkg/muxer/tcp/mux.go b/pkg/muxer/tcp/mux.go new file mode 100644 index 000000000..260761a74 --- /dev/null +++ b/pkg/muxer/tcp/mux.go @@ -0,0 +1,328 @@ +package tcp + +import ( + "errors" + "fmt" + "net" + "regexp" + "sort" + "strings" + + "github.com/traefik/traefik/v2/pkg/ip" + "github.com/traefik/traefik/v2/pkg/log" + "github.com/traefik/traefik/v2/pkg/rules" + "github.com/traefik/traefik/v2/pkg/tcp" + "github.com/traefik/traefik/v2/pkg/types" + "github.com/vulcand/predicate" +) + +var tcpFuncs = map[string]func(*matchersTree, ...string) error{ + "HostSNI": hostSNI, + "ClientIP": clientIP, +} + +// ParseHostSNI extracts the HostSNIs declared in a rule. +// This is a first naive implementation used in TCP routing. +func ParseHostSNI(rule string) ([]string, error) { + var matchers []string + for matcher := range tcpFuncs { + matchers = append(matchers, matcher) + } + + parser, err := rules.NewParser(matchers) + if err != nil { + return nil, err + } + + parse, err := parser.Parse(rule) + if err != nil { + return nil, err + } + + buildTree, ok := parse.(rules.TreeBuilder) + if !ok { + return nil, fmt.Errorf("error while parsing rule %s", rule) + } + + return buildTree().ParseMatchers([]string{"HostSNI"}), nil +} + +// ConnData contains TCP connection metadata. +type ConnData struct { + serverName string + remoteIP string +} + +// NewConnData builds a connData struct from the given parameters. +func NewConnData(serverName string, conn tcp.WriteCloser) (ConnData, error) { + remoteIP, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + return ConnData{}, fmt.Errorf("error while parsing remote address %q: %w", conn.RemoteAddr().String(), err) + } + + // as per https://datatracker.ietf.org/doc/html/rfc6066: + // > The hostname is represented as a byte string using ASCII encoding without a trailing dot. + // so there is no need to trim a potential trailing dot + serverName = types.CanonicalDomain(serverName) + + return ConnData{ + serverName: types.CanonicalDomain(serverName), + remoteIP: remoteIP, + }, nil +} + +// Muxer defines a muxer that handles TCP routing with rules. +type Muxer struct { + routes []*route + parser predicate.Parser +} + +// NewMuxer returns a TCP muxer. +func NewMuxer() (*Muxer, error) { + var matcherNames []string + for matcherName := range tcpFuncs { + matcherNames = append(matcherNames, matcherName) + } + + parser, err := rules.NewParser(matcherNames) + if err != nil { + return nil, fmt.Errorf("error while creating rules parser: %w", err) + } + + return &Muxer{parser: parser}, nil +} + +// Match returns the handler of the first route matching the connection metadata. +func (m Muxer) Match(meta ConnData) tcp.Handler { + for _, route := range m.routes { + if route.matchers.match(meta) { + return route.handler + } + } + + return nil +} + +// AddRoute adds a new route, associated to the given handler, at the given +// priority, to the muxer. +func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error { + // Special case for when the catchAll fallback is present. + // When no user-defined priority is found, the lowest computable priority minus one is used, + // in order to make the fallback the last to be evaluated. + if priority == 0 && rule == "HostSNI(`*`)" { + priority = -1 + } + + // Default value, which means the user has not set it, so we'll compute it. + if priority == 0 { + priority = len(rule) + } + + parse, err := m.parser.Parse(rule) + if err != nil { + return fmt.Errorf("error while parsing rule %s: %w", rule, err) + } + + buildTree, ok := parse.(rules.TreeBuilder) + if !ok { + return fmt.Errorf("error while parsing rule %s", rule) + } + + var matchers matchersTree + err = addRule(&matchers, buildTree()) + if err != nil { + return err + } + + newRoute := &route{ + handler: handler, + priority: priority, + matchers: matchers, + } + m.routes = append(m.routes, newRoute) + + sort.Sort(routes(m.routes)) + + return nil +} + +func addRule(tree *matchersTree, rule *rules.Tree) error { + switch rule.Matcher { + case "and", "or": + tree.operator = rule.Matcher + tree.left = &matchersTree{} + err := addRule(tree.left, rule.RuleLeft) + if err != nil { + return err + } + + tree.right = &matchersTree{} + return addRule(tree.right, rule.RuleRight) + default: + err := rules.CheckRule(rule) + if err != nil { + return err + } + + err = tcpFuncs[rule.Matcher](tree, rule.Value...) + if err != nil { + return err + } + + if rule.Not { + matcherFunc := tree.matcher + tree.matcher = func(meta ConnData) bool { + return !matcherFunc(meta) + } + } + } + + return nil +} + +// HasRoutes returns whether the muxer has routes. +func (m *Muxer) HasRoutes() bool { + return len(m.routes) > 0 +} + +// routes implements sort.Interface. +type routes []*route + +// Len implements sort.Interface. +func (r routes) Len() int { return len(r) } + +// Swap implements sort.Interface. +func (r routes) Swap(i, j int) { r[i], r[j] = r[j], r[i] } + +// Less implements sort.Interface. +func (r routes) Less(i, j int) bool { return r[i].priority > r[j].priority } + +// route holds the matchers to match TCP route, +// and the handler that will serve the connection. +type route struct { + // matchers tree structure reflecting the rule. + matchers matchersTree + // handler responsible for handling the route. + handler tcp.Handler + + // Used to disambiguate between two (or more) rules that would both match for a + // given request. + // Computed from the matching rule length, if not user-set. + priority int +} + +// matcher is a matcher func used to match connection properties. +type matcher func(meta ConnData) bool + +// matchersTree represents the matchers tree structure. +type matchersTree struct { + // If matcher is not nil, it means that this matcherTree is a leaf of the tree. + // It is therefore mutually exclusive with left and right. + matcher matcher + // operator to combine the evaluation of left and right leaves. + operator string + // Mutually exclusive with matcher. + left *matchersTree + right *matchersTree +} + +func (m *matchersTree) match(meta ConnData) bool { + if m == nil { + // This should never happen as it should have been detected during parsing. + log.WithoutContext().Warnf("Rule matcher is nil") + return false + } + + if m.matcher != nil { + return m.matcher(meta) + } + + switch m.operator { + case "or": + return m.left.match(meta) || m.right.match(meta) + case "and": + return m.left.match(meta) && m.right.match(meta) + default: + // This should never happen as it should have been detected during parsing. + log.WithoutContext().Warnf("Invalid rule operator %s", m.operator) + return false + } +} + +func clientIP(tree *matchersTree, clientIPs ...string) error { + checker, err := ip.NewChecker(clientIPs) + if err != nil { + return fmt.Errorf("could not initialize IP Checker for \"ClientIP\" matcher: %w", err) + } + + tree.matcher = func(meta ConnData) bool { + if meta.remoteIP == "" { + return false + } + + ok, err := checker.Contains(meta.remoteIP) + if err != nil { + log.WithoutContext().Warnf("\"ClientIP\" matcher: could not match remote address: %v", err) + return false + } + return ok + } + + return nil +} + +var almostFQDN = regexp.MustCompile(`^[[:alnum:]\.-]+$`) + +// hostSNI checks if the SNI Host of the connection match the matcher host. +func hostSNI(tree *matchersTree, hosts ...string) error { + if len(hosts) == 0 { + return errors.New("empty value for \"HostSNI\" matcher is not allowed") + } + + for i, host := range hosts { + // Special case to allow global wildcard + if host == "*" { + continue + } + + if !almostFQDN.MatchString(host) { + return fmt.Errorf("invalid value for \"HostSNI\" matcher, %q is not a valid hostname", host) + } + + hosts[i] = strings.ToLower(host) + } + + tree.matcher = func(meta ConnData) bool { + // Since a HostSNI(`*`) rule has been provided as catchAll for non-TLS TCP, + // it allows matching with an empty serverName. + // Which is why we make sure to take that case into account before + // checking meta.serverName. + if hosts[0] == "*" { + return true + } + + if meta.serverName == "" { + return false + } + + for _, host := range hosts { + if host == "*" { + return true + } + + if host == meta.serverName { + return true + } + + // trim trailing period in case of FQDN + host = strings.TrimSuffix(host, ".") + if host == meta.serverName { + return true + } + } + + return false + } + + return nil +} diff --git a/pkg/muxer/tcp/mux_test.go b/pkg/muxer/tcp/mux_test.go new file mode 100644 index 000000000..44be7c792 --- /dev/null +++ b/pkg/muxer/tcp/mux_test.go @@ -0,0 +1,776 @@ +package tcp + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/traefik/traefik/v2/pkg/tcp" +) + +type fakeConn struct { + call map[string]int + remoteAddr net.Addr +} + +func (f *fakeConn) Read(b []byte) (n int, err error) { + panic("implement me") +} + +func (f *fakeConn) Write(b []byte) (n int, err error) { + f.call[string(b)]++ + return len(b), nil +} + +func (f *fakeConn) Close() error { + panic("implement me") +} + +func (f *fakeConn) LocalAddr() net.Addr { + panic("implement me") +} + +func (f *fakeConn) RemoteAddr() net.Addr { + return f.remoteAddr +} + +func (f *fakeConn) SetDeadline(t time.Time) error { + panic("implement me") +} + +func (f *fakeConn) SetReadDeadline(t time.Time) error { + panic("implement me") +} + +func (f *fakeConn) SetWriteDeadline(t time.Time) error { + panic("implement me") +} + +func (f *fakeConn) CloseWrite() error { + panic("implement me") +} + +func Test_addTCPRoute(t *testing.T) { + testCases := []struct { + desc string + rule string + serverName string + remoteAddr string + routeErr bool + matchErr bool + }{ + { + desc: "no tree", + routeErr: true, + }, + { + desc: "Rule with no matcher", + rule: "rulewithnotmatcher", + routeErr: true, + }, + { + desc: "Empty HostSNI rule", + rule: "HostSNI()", + serverName: "foobar", + routeErr: true, + }, + { + desc: "Empty HostSNI rule", + rule: "HostSNI(``)", + serverName: "foobar", + routeErr: true, + }, + { + desc: "Valid HostSNI rule matching", + rule: "HostSNI(`foobar`)", + serverName: "foobar", + }, + { + desc: "Valid negative HostSNI rule matching", + rule: "!HostSNI(`bar`)", + serverName: "foobar", + }, + { + desc: "Valid HostSNI rule matching with alternative case", + rule: "hostsni(`foobar`)", + serverName: "foobar", + }, + { + desc: "Valid HostSNI rule matching with alternative case", + rule: "HOSTSNI(`foobar`)", + serverName: "foobar", + }, + { + desc: "Valid HostSNI rule not matching", + rule: "HostSNI(`foobar`)", + serverName: "bar", + matchErr: true, + }, + { + desc: "Valid negative HostSNI rule not matching", + rule: "!HostSNI(`bar`)", + serverName: "bar", + matchErr: true, + }, + { + desc: "Empty ClientIP rule", + rule: "ClientIP()", + routeErr: true, + }, + { + desc: "Empty ClientIP rule", + rule: "ClientIP(``)", + routeErr: true, + }, + { + desc: "Invalid ClientIP", + rule: "ClientIP(`invalid`)", + routeErr: true, + }, + { + desc: "Invalid remoteAddr", + rule: "ClientIP(`10.0.0.1`)", + remoteAddr: "not.an.IP:80", + matchErr: true, + }, + { + desc: "Valid ClientIP rule matching", + rule: "ClientIP(`10.0.0.1`)", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid negative ClientIP rule matching", + rule: "!ClientIP(`20.0.0.1`)", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid ClientIP rule matching with alternative case", + rule: "clientip(`10.0.0.1`)", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid ClientIP rule matching with alternative case", + rule: "CLIENTIP(`10.0.0.1`)", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid ClientIP rule not matching", + rule: "ClientIP(`10.0.0.1`)", + remoteAddr: "10.0.0.2:80", + matchErr: true, + }, + { + desc: "Valid negative ClientIP rule not matching", + rule: "!ClientIP(`10.0.0.2`)", + remoteAddr: "10.0.0.2:80", + matchErr: true, + }, + { + desc: "Valid ClientIP rule matching IPv6", + rule: "ClientIP(`10::10`)", + remoteAddr: "[10::10]:80", + }, + { + desc: "Valid negative ClientIP rule matching IPv6", + rule: "!ClientIP(`10::10`)", + remoteAddr: "[::1]:80", + }, + { + desc: "Valid ClientIP rule not matching IPv6", + rule: "ClientIP(`10::10`)", + remoteAddr: "[::1]:80", + matchErr: true, + }, + { + desc: "Valid ClientIP rule matching multiple IPs", + rule: "ClientIP(`10.0.0.1`, `10.0.0.0`)", + remoteAddr: "10.0.0.0:80", + }, + { + desc: "Valid ClientIP rule matching CIDR", + rule: "ClientIP(`11.0.0.0/24`)", + remoteAddr: "11.0.0.0:80", + }, + { + desc: "Valid ClientIP rule not matching CIDR", + rule: "ClientIP(`11.0.0.0/24`)", + remoteAddr: "10.0.0.0:80", + matchErr: true, + }, + { + desc: "Valid ClientIP rule matching CIDR IPv6", + rule: "ClientIP(`11::/16`)", + remoteAddr: "[11::]:80", + }, + { + desc: "Valid ClientIP rule not matching CIDR IPv6", + rule: "ClientIP(`11::/16`)", + remoteAddr: "[10::]:80", + matchErr: true, + }, + { + desc: "Valid ClientIP rule matching multiple CIDR", + rule: "ClientIP(`11.0.0.0/16`, `10.0.0.0/16`)", + remoteAddr: "10.0.0.0:80", + }, + { + desc: "Valid ClientIP rule not matching CIDR and matching IP", + rule: "ClientIP(`11.0.0.0/16`, `10.0.0.0`)", + remoteAddr: "10.0.0.0:80", + }, + { + desc: "Valid ClientIP rule matching CIDR and not matching IP", + rule: "ClientIP(`11.0.0.0`, `10.0.0.0/16`)", + remoteAddr: "10.0.0.0:80", + }, + { + desc: "Valid HostSNI and ClientIP rule matching", + rule: "HostSNI(`foobar`) && ClientIP(`10.0.0.1`)", + serverName: "foobar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid negative HostSNI and ClientIP rule matching", + rule: "!HostSNI(`bar`) && ClientIP(`10.0.0.1`)", + serverName: "foobar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid HostSNI and negative ClientIP rule matching", + rule: "HostSNI(`foobar`) && !ClientIP(`10.0.0.2`)", + serverName: "foobar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid negative HostSNI and negative ClientIP rule matching", + rule: "!HostSNI(`bar`) && !ClientIP(`10.0.0.2`)", + serverName: "foobar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid negative HostSNI or negative ClientIP rule matching", + rule: "!(HostSNI(`bar`) || ClientIP(`10.0.0.2`))", + serverName: "foobar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid negative HostSNI and negative ClientIP rule matching", + rule: "!(HostSNI(`bar`) && ClientIP(`10.0.0.2`))", + serverName: "foobar", + remoteAddr: "10.0.0.2:80", + }, + { + desc: "Valid negative HostSNI and negative ClientIP rule matching", + rule: "!(HostSNI(`bar`) && ClientIP(`10.0.0.2`))", + serverName: "bar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid negative HostSNI and negative ClientIP rule matching", + rule: "!(HostSNI(`bar`) && ClientIP(`10.0.0.2`))", + serverName: "bar", + remoteAddr: "10.0.0.2:80", + matchErr: true, + }, + { + desc: "Valid negative HostSNI and negative ClientIP rule matching", + rule: "!(HostSNI(`bar`) && ClientIP(`10.0.0.2`))", + serverName: "foobar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid HostSNI and ClientIP rule not matching", + rule: "HostSNI(`foobar`) && ClientIP(`10.0.0.1`)", + serverName: "bar", + remoteAddr: "10.0.0.1:80", + matchErr: true, + }, + { + desc: "Valid HostSNI and ClientIP rule not matching", + rule: "HostSNI(`foobar`) && ClientIP(`10.0.0.1`)", + serverName: "foobar", + remoteAddr: "10.0.0.2:80", + matchErr: true, + }, + { + desc: "Valid HostSNI or ClientIP rule matching", + rule: "HostSNI(`foobar`) || ClientIP(`10.0.0.1`)", + serverName: "foobar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid HostSNI or ClientIP rule matching", + rule: "HostSNI(`foobar`) || ClientIP(`10.0.0.1`)", + serverName: "bar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid HostSNI or ClientIP rule matching", + rule: "HostSNI(`foobar`) || ClientIP(`10.0.0.1`)", + serverName: "foobar", + remoteAddr: "10.0.0.2:80", + }, + { + desc: "Valid HostSNI or ClientIP rule not matching", + rule: "HostSNI(`foobar`) || ClientIP(`10.0.0.1`)", + serverName: "bar", + remoteAddr: "10.0.0.2:80", + matchErr: true, + }, + { + desc: "Valid HostSNI x 3 OR rule matching", + rule: "HostSNI(`foobar`) || HostSNI(`foo`) || HostSNI(`bar`)", + serverName: "foobar", + }, + { + desc: "Valid HostSNI x 3 OR rule not matching", + rule: "HostSNI(`foobar`) || HostSNI(`foo`) || HostSNI(`bar`)", + serverName: "baz", + matchErr: true, + }, + { + desc: "Valid HostSNI and ClientIP Combined rule matching", + rule: "HostSNI(`foobar`) || HostSNI(`bar`) && ClientIP(`10.0.0.1`)", + serverName: "foobar", + remoteAddr: "10.0.0.2:80", + }, + { + desc: "Valid HostSNI and ClientIP Combined rule matching", + rule: "HostSNI(`foobar`) || HostSNI(`bar`) && ClientIP(`10.0.0.1`)", + serverName: "bar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid HostSNI and ClientIP Combined rule not matching", + rule: "HostSNI(`foobar`) || HostSNI(`bar`) && ClientIP(`10.0.0.1`)", + serverName: "bar", + remoteAddr: "10.0.0.2:80", + matchErr: true, + }, + { + desc: "Valid HostSNI and ClientIP Combined rule not matching", + rule: "HostSNI(`foobar`) || HostSNI(`bar`) && ClientIP(`10.0.0.1`)", + serverName: "baz", + remoteAddr: "10.0.0.1:80", + matchErr: true, + }, + { + desc: "Valid HostSNI and ClientIP complex combined rule matching", + rule: "(HostSNI(`foobar`) || HostSNI(`bar`)) && (ClientIP(`10.0.0.1`) || ClientIP(`10.0.0.2`))", + serverName: "bar", + remoteAddr: "10.0.0.1:80", + }, + { + desc: "Valid HostSNI and ClientIP complex combined rule not matching", + rule: "(HostSNI(`foobar`) || HostSNI(`bar`)) && (ClientIP(`10.0.0.1`) || ClientIP(`10.0.0.2`))", + serverName: "baz", + remoteAddr: "10.0.0.1:80", + matchErr: true, + }, + { + desc: "Valid HostSNI and ClientIP complex combined rule not matching", + rule: "(HostSNI(`foobar`) || HostSNI(`bar`)) && (ClientIP(`10.0.0.1`) || ClientIP(`10.0.0.2`))", + serverName: "bar", + remoteAddr: "10.0.0.3:80", + matchErr: true, + }, + { + desc: "Valid HostSNI and ClientIP more complex (but absurd) combined rule matching", + rule: "(HostSNI(`foobar`) || (HostSNI(`bar`) && !HostSNI(`foobar`))) && ((ClientIP(`10.0.0.1`) && !ClientIP(`10.0.0.2`)) || ClientIP(`10.0.0.2`)) ", + serverName: "bar", + remoteAddr: "10.0.0.1:80", + }, + } + + for _, test := range testCases { + test := test + + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + msg := "BYTES" + handler := tcp.HandlerFunc(func(conn tcp.WriteCloser) { + _, err := conn.Write([]byte(msg)) + require.NoError(t, err) + }) + + router, err := NewMuxer() + require.NoError(t, err) + + err = router.AddRoute(test.rule, 0, handler) + if test.routeErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + + addr := "0.0.0.0:0" + if test.remoteAddr != "" { + addr = test.remoteAddr + } + + conn := &fakeConn{ + call: map[string]int{}, + remoteAddr: fakeAddr{addr: addr}, + } + + connData, err := NewConnData(test.serverName, conn) + require.NoError(t, err) + + matchingHandler := router.Match(connData) + if test.matchErr { + require.Nil(t, matchingHandler) + return + } + + require.NotNil(t, matchingHandler) + + matchingHandler.ServeTCP(conn) + + n, ok := conn.call[msg] + assert.Equal(t, n, 1) + assert.True(t, ok) + }) + } +} + +type fakeAddr struct { + addr string +} + +func (f fakeAddr) String() string { + return f.addr +} + +func (f fakeAddr) Network() string { + panic("Implement me") +} + +func TestParseHostSNI(t *testing.T) { + testCases := []struct { + description string + expression string + domain []string + errorExpected bool + }{ + { + description: "Unknown rule", + expression: "Foobar(`foo.bar`,`test.bar`)", + errorExpected: true, + }, + { + description: "Many hostSNI rules", + expression: "HostSNI(`foo.bar`,`test.bar`)", + domain: []string{"foo.bar", "test.bar"}, + }, + { + description: "Many hostSNI rules upper", + expression: "HOSTSNI(`foo.bar`,`test.bar`)", + domain: []string{"foo.bar", "test.bar"}, + }, + { + description: "Many hostSNI rules lower", + expression: "hostsni(`foo.bar`,`test.bar`)", + domain: []string{"foo.bar", "test.bar"}, + }, + { + description: "No hostSNI rule", + expression: "ClientIP(`10.1`)", + }, + { + description: "HostSNI rule and another rule", + expression: "HostSNI(`foo.bar`) && ClientIP(`10.1`)", + domain: []string{"foo.bar"}, + }, + { + description: "HostSNI rule to lower and another rule", + expression: "HostSNI(`Foo.Bar`) && ClientIP(`10.1`)", + domain: []string{"foo.bar"}, + }, + { + description: "HostSNI rule with no domain", + expression: "HostSNI() && ClientIP(`10.1`)", + }, + } + + for _, test := range testCases { + test := test + t.Run(test.expression, func(t *testing.T) { + t.Parallel() + + domains, err := ParseHostSNI(test.expression) + + if test.errorExpected { + require.Errorf(t, err, "unable to parse correctly the domains in the HostSNI rule from %q", test.expression) + } else { + require.NoError(t, err, "%s: Error while parsing domain.", test.expression) + } + + assert.EqualValues(t, test.domain, domains, "%s: Error parsing domains from expression.", test.expression) + }) + } +} + +func Test_HostSNI(t *testing.T) { + testCases := []struct { + desc string + ruleHosts []string + serverName string + buildErr bool + matchErr bool + }{ + { + desc: "Empty", + buildErr: true, + }, + { + desc: "Non ASCII host", + ruleHosts: []string{"héhé"}, + buildErr: true, + }, + { + desc: "Not Matching hosts", + ruleHosts: []string{"foobar"}, + serverName: "bar", + matchErr: true, + }, + { + desc: "Matching globing host `*`", + ruleHosts: []string{"*"}, + serverName: "foobar", + }, + { + desc: "Matching globing host `*` and empty serverName", + ruleHosts: []string{"*"}, + serverName: "", + }, + { + desc: "Matching globing host `*` and another non matching host", + ruleHosts: []string{"foo", "*"}, + serverName: "bar", + }, + { + desc: "Matching globing host `*` and another non matching host, and empty servername", + ruleHosts: []string{"foo", "*"}, + serverName: "", + matchErr: true, + }, + { + desc: "Not Matching globing host with subdomain", + ruleHosts: []string{"*.bar"}, + buildErr: true, + }, + { + desc: "Not Matching host with trailing dot with ", + ruleHosts: []string{"foobar."}, + serverName: "foobar.", + }, + { + desc: "Matching host with trailing dot", + ruleHosts: []string{"foobar."}, + serverName: "foobar", + }, + { + desc: "Matching hosts", + ruleHosts: []string{"foobar"}, + serverName: "foobar", + }, + { + desc: "Matching hosts with subdomains", + ruleHosts: []string{"foo.bar"}, + serverName: "foo.bar", + }, + } + + for _, test := range testCases { + test := test + + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + matcherTree := &matchersTree{} + err := hostSNI(matcherTree, test.ruleHosts...) + if test.buildErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + meta := ConnData{ + serverName: test.serverName, + } + + assert.Equal(t, test.matchErr, !matcherTree.match(meta)) + }) + } +} + +func Test_ClientIP(t *testing.T) { + testCases := []struct { + desc string + ruleCIDRs []string + remoteIP string + buildErr bool + matchErr bool + }{ + { + desc: "Empty", + buildErr: true, + }, + { + desc: "Malformed CIDR", + ruleCIDRs: []string{"héhé"}, + buildErr: true, + }, + { + desc: "Not matching empty remote IP", + ruleCIDRs: []string{"20.20.20.20"}, + matchErr: true, + }, + { + desc: "Not matching IP", + ruleCIDRs: []string{"20.20.20.20"}, + remoteIP: "10.10.10.10", + matchErr: true, + }, + { + desc: "Matching IP", + ruleCIDRs: []string{"10.10.10.10"}, + remoteIP: "10.10.10.10", + }, + { + desc: "Not matching multiple IPs", + ruleCIDRs: []string{"20.20.20.20", "30.30.30.30"}, + remoteIP: "10.10.10.10", + matchErr: true, + }, + { + desc: "Matching multiple IPs", + ruleCIDRs: []string{"10.10.10.10", "20.20.20.20", "30.30.30.30"}, + remoteIP: "20.20.20.20", + }, + { + desc: "Not matching CIDR", + ruleCIDRs: []string{"20.0.0.0/24"}, + remoteIP: "10.10.10.10", + matchErr: true, + }, + { + desc: "Matching CIDR", + ruleCIDRs: []string{"20.0.0.0/8"}, + remoteIP: "20.10.10.10", + }, + { + desc: "Not matching multiple CIDRs", + ruleCIDRs: []string{"10.0.0.0/24", "20.0.0.0/24"}, + remoteIP: "10.10.10.10", + matchErr: true, + }, + { + desc: "Matching multiple CIDRs", + ruleCIDRs: []string{"10.0.0.0/8", "20.0.0.0/8"}, + remoteIP: "20.10.10.10", + }, + } + + for _, test := range testCases { + test := test + + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + matchersTree := &matchersTree{} + err := clientIP(matchersTree, test.ruleCIDRs...) + if test.buildErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + meta := ConnData{ + remoteIP: test.remoteIP, + } + + assert.Equal(t, test.matchErr, !matchersTree.match(meta)) + }) + } +} + +func Test_Priority(t *testing.T) { + testCases := []struct { + desc string + rules map[string]int + serverName string + remoteIP string + expectedRule string + }{ + { + desc: "One matching rule, calculated priority", + rules: map[string]int{ + "HostSNI(`bar`)": 0, + "HostSNI(`foobar`)": 0, + }, + expectedRule: "HostSNI(`bar`)", + serverName: "bar", + }, + { + desc: "One matching rule, custom priority", + rules: map[string]int{ + "HostSNI(`foobar`)": 0, + "HostSNI(`bar`)": 10000, + }, + expectedRule: "HostSNI(`foobar`)", + serverName: "foobar", + }, + { + desc: "Two matching rules, calculated priority", + rules: map[string]int{ + "HostSNI(`foobar`)": 0, + "HostSNI(`foobar`, `bar`)": 0, + }, + expectedRule: "HostSNI(`foobar`, `bar`)", + serverName: "foobar", + }, + { + desc: "Two matching rules, custom priority", + rules: map[string]int{ + "HostSNI(`foobar`)": 10000, + "HostSNI(`foobar`, `bar`)": 0, + }, + expectedRule: "HostSNI(`foobar`)", + serverName: "foobar", + }, + } + + for _, test := range testCases { + test := test + + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + muxer, err := NewMuxer() + require.NoError(t, err) + + matchedRule := "" + for rule, priority := range test.rules { + rule := rule + err := muxer.AddRoute(rule, priority, tcp.HandlerFunc(func(conn tcp.WriteCloser) { + matchedRule = rule + })) + require.NoError(t, err) + } + + handler := muxer.Match(ConnData{ + serverName: test.serverName, + remoteIP: test.remoteIP, + }) + require.NotNil(t, handler) + + handler.ServeTCP(nil) + assert.Equal(t, test.expectedRule, matchedRule) + }) + } +} diff --git a/pkg/provider/acme/provider.go b/pkg/provider/acme/provider.go index 6a8eda985..df02a5a32 100644 --- a/pkg/provider/acme/provider.go +++ b/pkg/provider/acme/provider.go @@ -21,7 +21,8 @@ import ( ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/log" - "github.com/traefik/traefik/v2/pkg/rules" + httpmuxer "github.com/traefik/traefik/v2/pkg/muxer/http" + tcpmuxer "github.com/traefik/traefik/v2/pkg/muxer/tcp" "github.com/traefik/traefik/v2/pkg/safe" traefiktls "github.com/traefik/traefik/v2/pkg/tls" "github.com/traefik/traefik/v2/pkg/types" @@ -423,7 +424,7 @@ func (p *Provider) watchNewDomains(ctx context.Context) { }) } } else { - domains, err := rules.ParseHostSNI(route.Rule) + domains, err := tcpmuxer.ParseHostSNI(route.Rule) if err != nil { logger.Errorf("Error parsing domains in provider ACME: %v", err) continue @@ -452,7 +453,7 @@ func (p *Provider) watchNewDomains(ctx context.Context) { }) } } else { - domains, err := rules.ParseDomains(route.Rule) + domains, err := httpmuxer.ParseDomains(route.Rule) if err != nil { log.FromContext(ctxRouter).Errorf("Error parsing domains in provider ACME: %v", err) continue diff --git a/pkg/provider/kubernetes/crd/kubernetes_tcp.go b/pkg/provider/kubernetes/crd/kubernetes_tcp.go index 7ab950427..a9c134de5 100644 --- a/pkg/provider/kubernetes/crd/kubernetes_tcp.go +++ b/pkg/provider/kubernetes/crd/kubernetes_tcp.go @@ -98,6 +98,7 @@ func (p *Provider) loadIngressRouteTCPConfiguration(ctx context.Context, client EntryPoints: ingressRouteTCP.Spec.EntryPoints, Middlewares: mds, Rule: route.Match, + Priority: route.Priority, Service: serviceName, } diff --git a/pkg/provider/kubernetes/crd/traefik/v1alpha1/ingressroutetcp.go b/pkg/provider/kubernetes/crd/traefik/v1alpha1/ingressroutetcp.go index ef3af41cb..ef62542e8 100644 --- a/pkg/provider/kubernetes/crd/traefik/v1alpha1/ingressroutetcp.go +++ b/pkg/provider/kubernetes/crd/traefik/v1alpha1/ingressroutetcp.go @@ -17,6 +17,7 @@ type IngressRouteTCPSpec struct { // RouteTCP contains the set of routes. type RouteTCP struct { Match string `json:"match"` + Priority int `json:"priority,omitempty"` Services []ServiceTCP `json:"services,omitempty"` // Middlewares contains references to MiddlewareTCP resources. Middlewares []ObjectReference `json:"middlewares,omitempty"` diff --git a/pkg/rules/parser.go b/pkg/rules/parser.go index f1224aaad..b984ddb5a 100644 --- a/pkg/rules/parser.go +++ b/pkg/rules/parser.go @@ -1,7 +1,7 @@ package rules import ( - "errors" + "fmt" "strings" "github.com/vulcand/predicate" @@ -12,121 +12,29 @@ const ( or = "or" ) -type treeBuilder func() *tree +// TreeBuilder defines the type for a Tree builder. +type TreeBuilder func() *Tree -// ParseDomains extract domains from rule. -func ParseDomains(rule string) ([]string, error) { - parser, err := newParser() - if err != nil { - return nil, err - } - - parse, err := parser.Parse(rule) - if err != nil { - return nil, err - } - - buildTree, ok := parse.(treeBuilder) - if !ok { - return nil, errors.New("cannot parse") - } - - return lower(parseDomain(buildTree())), nil +// Tree represents the rules' tree structure. +type Tree struct { + Matcher string + Not bool + Value []string + RuleLeft *Tree + RuleRight *Tree } -// ParseHostSNI extracts the HostSNIs declared in a rule. -// This is a first naive implementation used in TCP routing. -func ParseHostSNI(rule string) ([]string, error) { - parser, err := newTCPParser() - if err != nil { - return nil, err - } - - parse, err := parser.Parse(rule) - if err != nil { - return nil, err - } - - buildTree, ok := parse.(treeBuilder) - if !ok { - return nil, errors.New("cannot parse") - } - - return lower(parseDomain(buildTree())), nil -} - -func lower(slice []string) []string { - var lowerStrings []string - for _, value := range slice { - lowerStrings = append(lowerStrings, strings.ToLower(value)) - } - return lowerStrings -} - -func parseDomain(tree *tree) []string { - switch tree.matcher { - case and, or: - return append(parseDomain(tree.ruleLeft), parseDomain(tree.ruleRight)...) - case "Host", "HostSNI": - return tree.value - default: - return nil - } -} - -func andFunc(left, right treeBuilder) treeBuilder { - return func() *tree { - return &tree{ - matcher: and, - ruleLeft: left(), - ruleRight: right(), - } - } -} - -func orFunc(left, right treeBuilder) treeBuilder { - return func() *tree { - return &tree{ - matcher: or, - ruleLeft: left(), - ruleRight: right(), - } - } -} - -func invert(t *tree) *tree { - switch t.matcher { - case or: - t.matcher = and - t.ruleLeft = invert(t.ruleLeft) - t.ruleRight = invert(t.ruleRight) - case and: - t.matcher = or - t.ruleLeft = invert(t.ruleLeft) - t.ruleRight = invert(t.ruleRight) - default: - t.not = !t.not - } - - return t -} - -func notFunc(elem treeBuilder) treeBuilder { - return func() *tree { - return invert(elem()) - } -} - -func newParser() (predicate.Parser, error) { +// NewParser constructs a parser for the given matchers. +func NewParser(matchers []string) (predicate.Parser, error) { parserFuncs := make(map[string]interface{}) - for matcherName := range funcs { + for _, matcherName := range matchers { matcherName := matcherName - fn := func(value ...string) treeBuilder { - return func() *tree { - return &tree{ - matcher: matcherName, - value: value, + fn := func(value ...string) TreeBuilder { + return func() *Tree { + return &Tree{ + Matcher: matcherName, + Value: value, } } } @@ -146,28 +54,85 @@ func newParser() (predicate.Parser, error) { }) } -func newTCPParser() (predicate.Parser, error) { - parserFuncs := make(map[string]interface{}) - - // FIXME quircky way of waiting for new rules - matcherName := "HostSNI" - fn := func(value ...string) treeBuilder { - return func() *tree { - return &tree{ - matcher: matcherName, - value: value, - } +func andFunc(left, right TreeBuilder) TreeBuilder { + return func() *Tree { + return &Tree{ + Matcher: and, + RuleLeft: left(), + RuleRight: right(), } } - parserFuncs[matcherName] = fn - parserFuncs[strings.ToLower(matcherName)] = fn - parserFuncs[strings.ToUpper(matcherName)] = fn - parserFuncs[strings.Title(strings.ToLower(matcherName))] = fn - - return predicate.NewParser(predicate.Def{ - Operators: predicate.Operators{ - OR: orFunc, - }, - Functions: parserFuncs, - }) +} + +func orFunc(left, right TreeBuilder) TreeBuilder { + return func() *Tree { + return &Tree{ + Matcher: or, + RuleLeft: left(), + RuleRight: right(), + } + } +} + +func invert(t *Tree) *Tree { + switch t.Matcher { + case or: + t.Matcher = and + t.RuleLeft = invert(t.RuleLeft) + t.RuleRight = invert(t.RuleRight) + case and: + t.Matcher = or + t.RuleLeft = invert(t.RuleLeft) + t.RuleRight = invert(t.RuleRight) + default: + t.Not = !t.Not + } + + return t +} + +func notFunc(elem TreeBuilder) TreeBuilder { + return func() *Tree { + return invert(elem()) + } +} + +// ParseMatchers returns the subset of matchers in the Tree matching the given matchers. +func (tree *Tree) ParseMatchers(matchers []string) []string { + switch tree.Matcher { + case and, or: + return append(tree.RuleLeft.ParseMatchers(matchers), tree.RuleRight.ParseMatchers(matchers)...) + default: + for _, matcher := range matchers { + if tree.Matcher == matcher { + return lower(tree.Value) + } + } + + return nil + } +} + +// CheckRule validates the given rule. +func CheckRule(rule *Tree) error { + if len(rule.Value) == 0 { + return fmt.Errorf("no args for matcher %s", rule.Matcher) + } + + for _, v := range rule.Value { + if len(v) == 0 { + return fmt.Errorf("empty args for matcher %s, %v", rule.Matcher, rule.Value) + } + } + + return nil +} + +func lower(slice []string) []string { + var lowerStrings []string + for _, value := range slice { + lowerStrings = append(lowerStrings, strings.ToLower(value)) + } + + return lowerStrings } diff --git a/pkg/rules/parser_test.go b/pkg/rules/parser_test.go new file mode 100644 index 000000000..503ab4298 --- /dev/null +++ b/pkg/rules/parser_test.go @@ -0,0 +1,301 @@ +package rules + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testTree = Tree + CheckErr +type testTree struct { + Matcher string + Not bool + Value []string + RuleLeft *testTree + RuleRight *testTree + + // CheckErr allow knowing if a Tree has a rule error. + CheckErr bool +} + +func TestRuleMatch(t *testing.T) { + matchers := []string{"m"} + testCases := []struct { + desc string + rule string + tree testTree + matchers []string + values []string + expectParseErr bool + }{ + { + desc: "No rule", + rule: "", + expectParseErr: true, + }, + { + desc: "No matcher in rule", + rule: "m", + expectParseErr: true, + }, + { + desc: "No value in rule", + rule: "m()", + tree: testTree{ + Matcher: "m", + Value: []string{}, + CheckErr: true, + }, + }, + { + desc: "Empty value in rule", + rule: "m(``)", + tree: testTree{ + Matcher: "m", + Value: []string{""}, + CheckErr: true, + }, + matchers: []string{"m"}, + values: []string{""}, + }, + { + desc: "One value in rule with and", + rule: "m(`1`) &&", + expectParseErr: true, + }, + { + desc: "One value in rule with or", + rule: "m(`1`) ||", + expectParseErr: true, + }, + { + desc: "One value in rule with missing back tick", + rule: "m(`1)", + expectParseErr: true, + }, + { + desc: "One value in rule with missing opening parenthesis", + rule: "m(`1`))", + expectParseErr: true, + }, + { + desc: "One value in rule with missing closing parenthesis", + rule: "(m(`1`)", + expectParseErr: true, + }, + { + desc: "One value in rule", + rule: "m(`1`)", + tree: testTree{ + Matcher: "m", + Value: []string{"1"}, + }, + matchers: []string{"m"}, + values: []string{"1"}, + }, + { + desc: "One value in rule with superfluous parenthesis", + rule: "(m(`1`))", + tree: testTree{ + Matcher: "m", + Value: []string{"1"}, + }, + matchers: []string{"m"}, + values: []string{"1"}, + }, + { + desc: "Rule with CAPS matcher", + rule: "M(`1`)", + tree: testTree{ + Matcher: "m", + Value: []string{"1"}, + }, + matchers: []string{"m"}, + values: []string{"1"}, + }, + { + desc: "Invalid matcher in rule", + rule: "w(`1`)", + expectParseErr: true, + }, + { + desc: "Invalid matchers", + rule: "m(`1`)", + tree: testTree{ + Matcher: "m", + Value: []string{"1"}, + }, + matchers: []string{"not-m"}, + }, + { + desc: "Two value in rule", + rule: "m(`1`, `2`)", + tree: testTree{ + Matcher: "m", + Value: []string{"1", "2"}, + }, + matchers: []string{"m"}, + values: []string{"1", "2"}, + }, + { + desc: "Not one value in rule", + rule: "!m(`1`)", + tree: testTree{ + Matcher: "m", + Not: true, + Value: []string{"1"}, + }, + matchers: []string{"m"}, + values: []string{"1"}, + }, + { + desc: "Two value in rule with and", + rule: "m(`1`) && m(`2`)", + tree: testTree{ + Matcher: "and", + CheckErr: true, + RuleLeft: &testTree{ + Matcher: "m", + Value: []string{"1"}, + }, + RuleRight: &testTree{ + Matcher: "m", + Value: []string{"2"}, + }, + }, + matchers: []string{"m"}, + values: []string{"1", "2"}, + }, + { + desc: "Two value in rule with or", + rule: "m(`1`) || m(`2`)", + tree: testTree{ + Matcher: "or", + CheckErr: true, + RuleLeft: &testTree{ + Matcher: "m", + Value: []string{"1"}, + }, + RuleRight: &testTree{ + Matcher: "m", + Value: []string{"2"}, + }, + }, + matchers: []string{"m"}, + values: []string{"1", "2"}, + }, + { + desc: "Two value in rule with and negated", + rule: "!(m(`1`) && m(`2`))", + tree: testTree{ + Matcher: "or", + CheckErr: true, + RuleLeft: &testTree{ + Matcher: "m", + Not: true, + Value: []string{"1"}, + }, + RuleRight: &testTree{ + Matcher: "m", + Not: true, + Value: []string{"2"}, + }, + }, + matchers: []string{"m"}, + values: []string{"1", "2"}, + }, + { + desc: "Two value in rule with or negated", + rule: "!(m(`1`) || m(`2`))", + tree: testTree{ + Matcher: "and", + CheckErr: true, + RuleLeft: &testTree{ + Matcher: "m", + Not: true, + Value: []string{"1"}, + }, + RuleRight: &testTree{ + Matcher: "m", + Not: true, + Value: []string{"2"}, + }, + }, + matchers: []string{"m"}, + values: []string{"1", "2"}, + }, + { + desc: "No value in rule", + rule: "m(`1`) && m()", + tree: testTree{ + Matcher: "and", + CheckErr: true, + RuleLeft: &testTree{ + Matcher: "m", + Value: []string{"1"}, + }, + RuleRight: &testTree{ + Matcher: "m", + Value: []string{}, + CheckErr: true, + }, + }, + matchers: []string{"m"}, + values: []string{"1"}, + }, + } + + parser, err := NewParser(matchers) + require.NoError(t, err) + + for _, test := range testCases { + test := test + + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + parse, err := parser.Parse(test.rule) + if test.expectParseErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + treeBuilder, ok := parse.(TreeBuilder) + require.True(t, ok) + + tree := treeBuilder() + checkEquivalence(t, &test.tree, tree) + + assert.Equal(t, test.values, tree.ParseMatchers(test.matchers)) + }) + } +} + +func checkEquivalence(t *testing.T, expected *testTree, actual *Tree) { + t.Helper() + + if actual == nil { + return + } + + if actual.RuleLeft != nil { + checkEquivalence(t, expected.RuleLeft, actual.RuleLeft) + } + + if actual.RuleRight != nil { + checkEquivalence(t, expected.RuleRight, actual.RuleRight) + } + + assert.Equal(t, expected.Matcher, actual.Matcher) + assert.Equal(t, expected.Not, actual.Not) + assert.Equal(t, expected.Value, actual.Value) + + t.Logf("%+v -> %v", actual, CheckRule(actual)) + if expected.CheckErr { + assert.Error(t, CheckRule(actual)) + } else { + assert.NoError(t, CheckRule(actual)) + } +} diff --git a/pkg/server/router/router.go b/pkg/server/router/router.go index 7f9c40d0f..6f1819dbe 100644 --- a/pkg/server/router/router.go +++ b/pkg/server/router/router.go @@ -13,7 +13,7 @@ import ( metricsMiddle "github.com/traefik/traefik/v2/pkg/middlewares/metrics" "github.com/traefik/traefik/v2/pkg/middlewares/recovery" "github.com/traefik/traefik/v2/pkg/middlewares/tracing" - "github.com/traefik/traefik/v2/pkg/rules" + httpmuxer "github.com/traefik/traefik/v2/pkg/muxer/http" "github.com/traefik/traefik/v2/pkg/server/middleware" "github.com/traefik/traefik/v2/pkg/server/provider" ) @@ -102,7 +102,7 @@ func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, t } func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*runtime.RouterInfo) (http.Handler, error) { - router, err := rules.NewRouter() + muxer, err := httpmuxer.NewMuxer() if err != nil { return nil, err } @@ -118,7 +118,7 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string continue } - err = router.AddRoute(routerConfig.Rule, routerConfig.Priority, handler) + err = muxer.AddRoute(routerConfig.Rule, routerConfig.Priority, handler) if err != nil { routerConfig.AddError(err, true) logger.Error(err) @@ -126,14 +126,14 @@ func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string } } - router.SortRoutes() + muxer.SortRoutes() chain := alice.New() chain = chain.Append(func(next http.Handler) (http.Handler, error) { return recovery.New(ctx, next) }) - return chain.Then(router) + return chain.Then(muxer) } func (m *Manager) buildRouterHandler(ctx context.Context, routerName string, routerConfig *runtime.RouterInfo) (http.Handler, error) { diff --git a/pkg/server/router/tcp/manager.go b/pkg/server/router/tcp/manager.go new file mode 100644 index 000000000..ff92f3112 --- /dev/null +++ b/pkg/server/router/tcp/manager.go @@ -0,0 +1,365 @@ +package tcp + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + + "github.com/traefik/traefik/v2/pkg/config/runtime" + "github.com/traefik/traefik/v2/pkg/log" + "github.com/traefik/traefik/v2/pkg/middlewares/snicheck" + httpmuxer "github.com/traefik/traefik/v2/pkg/muxer/http" + tcpmuxer "github.com/traefik/traefik/v2/pkg/muxer/tcp" + "github.com/traefik/traefik/v2/pkg/server/provider" + tcpservice "github.com/traefik/traefik/v2/pkg/server/service/tcp" + "github.com/traefik/traefik/v2/pkg/tcp" + traefiktls "github.com/traefik/traefik/v2/pkg/tls" +) + +type middlewareBuilder interface { + BuildChain(ctx context.Context, names []string) *tcp.Chain +} + +// NewManager Creates a new Manager. +func NewManager(conf *runtime.Configuration, + serviceManager *tcpservice.Manager, + middlewaresBuilder middlewareBuilder, + httpHandlers map[string]http.Handler, + httpsHandlers map[string]http.Handler, + tlsManager *traefiktls.Manager, +) *Manager { + return &Manager{ + serviceManager: serviceManager, + middlewaresBuilder: middlewaresBuilder, + httpHandlers: httpHandlers, + httpsHandlers: httpsHandlers, + tlsManager: tlsManager, + conf: conf, + } +} + +// Manager is a route/router manager. +type Manager struct { + serviceManager *tcpservice.Manager + middlewaresBuilder middlewareBuilder + httpHandlers map[string]http.Handler + httpsHandlers map[string]http.Handler + tlsManager *traefiktls.Manager + conf *runtime.Configuration +} + +func (m *Manager) getTCPRouters(ctx context.Context, entryPoints []string) map[string]map[string]*runtime.TCPRouterInfo { + if m.conf != nil { + return m.conf.GetTCPRoutersByEntryPoints(ctx, entryPoints) + } + + return make(map[string]map[string]*runtime.TCPRouterInfo) +} + +func (m *Manager) getHTTPRouters(ctx context.Context, entryPoints []string, tls bool) map[string]map[string]*runtime.RouterInfo { + if m.conf != nil { + return m.conf.GetRoutersByEntryPoints(ctx, entryPoints, tls) + } + + return make(map[string]map[string]*runtime.RouterInfo) +} + +// BuildHandlers builds the handlers for the given entrypoints. +func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]*Router { + entryPointsRouters := m.getTCPRouters(rootCtx, entryPoints) + entryPointsRoutersHTTP := m.getHTTPRouters(rootCtx, entryPoints, true) + + entryPointHandlers := make(map[string]*Router) + for _, entryPointName := range entryPoints { + entryPointName := entryPointName + + routers := entryPointsRouters[entryPointName] + + ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName)) + + handler, err := m.buildEntryPointHandler(ctx, routers, entryPointsRoutersHTTP[entryPointName], m.httpHandlers[entryPointName], m.httpsHandlers[entryPointName]) + if err != nil { + log.FromContext(ctx).Error(err) + continue + } + entryPointHandlers[entryPointName] = handler + } + return entryPointHandlers +} + +type nameAndConfig struct { + routerName string // just so we have it as additional information when logging + TLSConfig *tls.Config +} + +func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*runtime.TCPRouterInfo, configsHTTP map[string]*runtime.RouterInfo, handlerHTTP, handlerHTTPS http.Handler) (*Router, error) { + // Build a new Router. + router, err := NewRouter() + if err != nil { + return nil, err + } + + router.SetHTTPHandler(handlerHTTP) + + defaultTLSConf, err := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, traefiktls.DefaultTLSConfigName) + if err != nil { + log.FromContext(ctx).Errorf("Error during the build of the default TLS configuration: %v", err) + } + + // Keyed by domain. The source of truth for doing SNI checking, and for what TLS + // options will actually be used for the connection. + // As soon as there's (at least) two different tlsOptions found for the same domain, + // we set the value to the default TLS conf. + tlsOptionsForHost := map[string]string{} + + // Keyed by domain, then by options reference. + // As opposed to tlsOptionsForHost, it keeps track of all the (different) TLS + // options that occur for a given host name, so that later on we can set relevant + // errors and logging for all the routers concerned (i.e. wrongly configured). + tlsOptionsForHostSNI := map[string]map[string]nameAndConfig{} + + for routerHTTPName, routerHTTPConfig := range configsHTTP { + if routerHTTPConfig.TLS == nil { + continue + } + + ctxRouter := log.With(provider.AddInContext(ctx, routerHTTPName), log.Str(log.RouterName, routerHTTPName)) + logger := log.FromContext(ctxRouter) + + tlsOptionsName := traefiktls.DefaultTLSConfigName + if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName { + tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options) + } + + domains, err := httpmuxer.ParseDomains(routerHTTPConfig.Rule) + if err != nil { + routerErr := fmt.Errorf("invalid rule %s, error: %w", routerHTTPConfig.Rule, err) + routerHTTPConfig.AddError(routerErr, true) + logger.Error(routerErr) + continue + } + + if len(domains) == 0 { + // Extra Host(*) rule, for HTTPS routers with no Host rule, and for requests for + // which the SNI does not match _any_ of the other existing routers Host. This is + // only about choosing the TLS configuration. The actual routing will be done + // further on by the HTTPS handler. See examples below. + router.AddHTTPTLSConfig("*", defaultTLSConf) + + // The server name (from a Host(SNI) rule) is the only parameter (available in + // HTTP routing rules) on which we can map a TLS config, because it is the only one + // accessible before decryption (we obtain it during the ClientHello). Therefore, + // when a router has no Host rule, it does not make any sense to specify some TLS + // options. Consequently, when it comes to deciding what TLS config will be used, + // for a request that will match an HTTPS router with no Host rule, the result will + // depend on the _others_ existing routers (their Host rule, to be precise), and + // the TLS options associated with them, even though they don't match the incoming + // request. Consider the following examples: + + // # conf1 + // httpRouter1: + // rule: PathPrefix("/foo") + // # Wherever the request comes from, the TLS config used will be the default one, because of the Host(*) fallback. + + // # conf2 + // httpRouter1: + // rule: PathPrefix("/foo") + // + // httpRouter2: + // rule: Host("foo.com") && PathPrefix("/bar") + // tlsoptions: myTLSOptions + // # When a request for "/foo" comes, even though it won't be routed by + // httpRouter2, if its SNI is set to foo.com, myTLSOptions will be used for the TLS + // connection. Otherwise, it will fallback to the default TLS config. + logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the SNI of each request", routerHTTPConfig.Rule) + } + + tlsConf, err := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, tlsOptionsName) + if err != nil { + routerHTTPConfig.AddError(err, true) + logger.Error(err) + continue + } + + for _, domain := range domains { + // domain is already in lower case thanks to the domain parsing + if tlsOptionsForHostSNI[domain] == nil { + tlsOptionsForHostSNI[domain] = make(map[string]nameAndConfig) + } + tlsOptionsForHostSNI[domain][tlsOptionsName] = nameAndConfig{ + routerName: routerHTTPName, + TLSConfig: tlsConf, + } + + if name, ok := tlsOptionsForHost[domain]; ok && name != tlsOptionsName { + // Different tlsOptions on the same domain, so fallback to default + tlsOptionsForHost[domain] = traefiktls.DefaultTLSConfigName + } else { + tlsOptionsForHost[domain] = tlsOptionsName + } + } + } + + sniCheck := snicheck.New(tlsOptionsForHost, handlerHTTPS) + + router.SetHTTPSHandler(sniCheck, defaultTLSConf) + + logger := log.FromContext(ctx) + for hostSNI, tlsConfigs := range tlsOptionsForHostSNI { + if len(tlsConfigs) == 1 { + var optionsName string + var config *tls.Config + for k, v := range tlsConfigs { + optionsName = k + config = v.TLSConfig + break + } + + logger.Debugf("Adding route for %s with TLS options %s", hostSNI, optionsName) + + router.AddHTTPTLSConfig(hostSNI, config) + } else { + routers := make([]string, 0, len(tlsConfigs)) + for _, v := range tlsConfigs { + configsHTTP[v.routerName].AddError(fmt.Errorf("found different TLS options for routers on the same host %v, so using the default TLS options instead", hostSNI), false) + routers = append(routers, v.routerName) + } + + logger.Warnf("Found different TLS options for routers on the same host %v, so using the default TLS options instead for these routers: %#v", hostSNI, routers) + + router.AddHTTPTLSConfig(hostSNI, defaultTLSConf) + } + } + + for routerName, routerConfig := range configs { + ctxRouter := log.With(provider.AddInContext(ctx, routerName), log.Str(log.RouterName, routerName)) + logger := log.FromContext(ctxRouter) + + if routerConfig.Service == "" { + err := errors.New("the service is missing on the router") + routerConfig.AddError(err, true) + logger.Error(err) + continue + } + + if routerConfig.Rule == "" { + err := errors.New("router has no rule") + routerConfig.AddError(err, true) + logger.Error(err) + continue + } + + handler, err := m.buildTCPHandler(ctxRouter, routerConfig) + if err != nil { + routerConfig.AddError(err, true) + logger.Error(err) + continue + } + + domains, err := tcpmuxer.ParseHostSNI(routerConfig.Rule) + if err != nil { + routerErr := fmt.Errorf("invalid rule: %q , %w", routerConfig.Rule, err) + routerConfig.AddError(routerErr, true) + logger.Error(routerErr) + continue + } + + // HostSNI Rule, but TLS not set on the router, which is an error. + // However, we allow the HostSNI(*) exception. + if len(domains) > 0 && routerConfig.TLS == nil && domains[0] != "*" { + routerErr := fmt.Errorf("invalid rule: %q , has HostSNI matcher, but no TLS on router", routerConfig.Rule) + routerConfig.AddError(routerErr, true) + logger.Error(routerErr) + } + + if routerConfig.TLS == nil { + logger.Debugf("Adding route for %q", routerConfig.Rule) + if err := router.AddRoute(routerConfig.Rule, routerConfig.Priority, handler); err != nil { + routerConfig.AddError(err, true) + logger.Error(err) + } + continue + } + + if routerConfig.TLS.Passthrough { + logger.Debugf("Adding Passthrough route for %q", routerConfig.Rule) + if err := router.AddRouteTLS(routerConfig.Rule, routerConfig.Priority, handler, nil); err != nil { + routerConfig.AddError(err, true) + logger.Error(err) + } + continue + } + + for _, domain := range domains { + if httpmuxer.IsASCII(domain) { + continue + } + + asciiError := fmt.Errorf("invalid domain name value %q, non-ASCII characters are not allowed", domain) + routerConfig.AddError(asciiError, true) + logger.Error(asciiError) + } + + tlsOptionsName := routerConfig.TLS.Options + + if len(tlsOptionsName) == 0 { + tlsOptionsName = traefiktls.DefaultTLSConfigName + } + + if tlsOptionsName != traefiktls.DefaultTLSConfigName { + tlsOptionsName = provider.GetQualifiedName(ctxRouter, tlsOptionsName) + } + + tlsConf, err := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, tlsOptionsName) + if err != nil { + routerConfig.AddError(err, true) + logger.Error(err) + continue + } + + // Now that the Rule is not just about the Host, we could theoretically have a config like: + // router1: + // rule: HostSNI(foo.com) && ClientIP(IP1) + // tlsOption: tlsOne + // router2: + // rule: HostSNI(foo.com) && ClientIP(IP2) + // tlsOption: tlsTwo + // i.e. same HostSNI but different tlsOptions + // This is only applicable if the muxer can decide about the routing _before_ + // telling the client about the tlsConf (i.e. before the TLS HandShake). This seems + // to be the case so far with the existing matchers (HostSNI, and ClientIP), so + // it's all good. Otherwise, we would have to do as for HTTPS, i.e. disallow + // different TLS configs for the same HostSNIs. + + logger.Debugf("Adding TLS route for %q", routerConfig.Rule) + if err := router.AddRouteTLS(routerConfig.Rule, routerConfig.Priority, handler, tlsConf); err != nil { + routerConfig.AddError(err, true) + logger.Error(err) + } + } + + return router, nil +} + +func (m *Manager) buildTCPHandler(ctx context.Context, router *runtime.TCPRouterInfo) (tcp.Handler, error) { + var qualifiedNames []string + for _, name := range router.Middlewares { + qualifiedNames = append(qualifiedNames, provider.GetQualifiedName(ctx, name)) + } + router.Middlewares = qualifiedNames + + if router.Service == "" { + return nil, errors.New("the service is missing on the router") + } + + sHandler, err := m.serviceManager.BuildTCP(ctx, router.Service) + if err != nil { + return nil, err + } + + mHandler := m.middlewaresBuilder.BuildChain(ctx, router.Middlewares) + + return tcp.NewChain().Extend(*mHandler).Then(sHandler) +} diff --git a/pkg/server/router/tcp/router_test.go b/pkg/server/router/tcp/manager_test.go similarity index 95% rename from pkg/server/router/tcp/router_test.go rename to pkg/server/router/tcp/manager_test.go index cdf7c3bbc..dd31ce3b3 100644 --- a/pkg/server/router/tcp/router_test.go +++ b/pkg/server/router/tcp/manager_test.go @@ -175,6 +175,7 @@ func TestRuntimeConfiguration(t *testing.T) { EntryPoints: []string{"web"}, Service: "foo-service", Rule: "HostSNI(`foo.bar`)", + TLS: &dynamic.RouterTCPTLSConfig{}, }, }, }, @@ -234,6 +235,7 @@ func TestRuntimeConfiguration(t *testing.T) { EntryPoints: []string{"web"}, Service: "wrong-service", Rule: "HostSNI(`bar.foo`)", + TLS: &dynamic.RouterTCPTLSConfig{}, }, }, "bar": { @@ -241,6 +243,7 @@ func TestRuntimeConfiguration(t *testing.T) { EntryPoints: []string{"web"}, Service: "foo-service", Rule: "HostSNI(`foo.bar`)", + TLS: &dynamic.RouterTCPTLSConfig{}, }, }, }, @@ -266,6 +269,32 @@ func TestRuntimeConfiguration(t *testing.T) { }, expectedError: 2, }, + { + desc: "Router with HostSNI but no TLS", + tcpServiceConfig: map[string]*runtime.TCPServiceInfo{ + "foo-service": { + TCPService: &dynamic.TCPService{ + LoadBalancer: &dynamic.TCPServersLoadBalancer{ + Servers: []dynamic.TCPServer{ + { + Address: "127.0.0.1:80", + }, + }, + }, + }, + }, + }, + tcpRouterConfig: map[string]*runtime.TCPRouterInfo{ + "foo": { + TCPRouter: &dynamic.TCPRouter{ + EntryPoints: []string{"web"}, + Service: "foo-service", + Rule: "HostSNI(`bar.foo`)", + }, + }, + }, + expectedError: 1, + }, } for _, test := range testCases { diff --git a/pkg/server/router/tcp/router.go b/pkg/server/router/tcp/router.go index 0a97a10b3..d9b5327c0 100644 --- a/pkg/server/router/tcp/router.go +++ b/pkg/server/router/tcp/router.go @@ -1,291 +1,364 @@ package tcp import ( - "context" + "bufio" + "bytes" "crypto/tls" "errors" - "fmt" + "io" + "net" "net/http" + "time" - "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/log" - "github.com/traefik/traefik/v2/pkg/middlewares/snicheck" - "github.com/traefik/traefik/v2/pkg/rules" - "github.com/traefik/traefik/v2/pkg/server/provider" - tcpservice "github.com/traefik/traefik/v2/pkg/server/service/tcp" + tcpmuxer "github.com/traefik/traefik/v2/pkg/muxer/tcp" "github.com/traefik/traefik/v2/pkg/tcp" - traefiktls "github.com/traefik/traefik/v2/pkg/tls" ) -type middlewareBuilder interface { - BuildChain(ctx context.Context, names []string) *tcp.Chain +const defaultBufSize = 4096 + +// Router is a TCP router. +type Router struct { + // Contains TCP routes. + muxerTCP tcpmuxer.Muxer + // Contains TCP TLS routes. + muxerTCPTLS tcpmuxer.Muxer + // Contains HTTPS routes. + muxerHTTPS tcpmuxer.Muxer + + // Forwarder handlers. + // Handles all HTTP requests. + httpForwarder tcp.Handler + // Handles (indirectly through muxerHTTPS, or directly) all HTTPS requests. + httpsForwarder tcp.Handler + + // Neither is used directly, but they are held here, and recreated on config + // reload, so that they can be passed to the Switcher at the end of the config + // reload phase. + httpHandler http.Handler + httpsHandler http.Handler + + // TLS configs. + httpsTLSConfig *tls.Config // default TLS config + hostHTTPTLSConfig map[string]*tls.Config // TLS configs keyed by SNI } -// NewManager Creates a new Manager. -func NewManager(conf *runtime.Configuration, - serviceManager *tcpservice.Manager, - middlewaresBuilder middlewareBuilder, - httpHandlers map[string]http.Handler, - httpsHandlers map[string]http.Handler, - tlsManager *traefiktls.Manager, -) *Manager { - return &Manager{ - serviceManager: serviceManager, - middlewaresBuilder: middlewaresBuilder, - httpHandlers: httpHandlers, - httpsHandlers: httpsHandlers, - tlsManager: tlsManager, - conf: conf, - } -} - -// Manager is a route/router manager. -type Manager struct { - serviceManager *tcpservice.Manager - middlewaresBuilder middlewareBuilder - httpHandlers map[string]http.Handler - httpsHandlers map[string]http.Handler - tlsManager *traefiktls.Manager - conf *runtime.Configuration -} - -func (m *Manager) getTCPRouters(ctx context.Context, entryPoints []string) map[string]map[string]*runtime.TCPRouterInfo { - if m.conf != nil { - return m.conf.GetTCPRoutersByEntryPoints(ctx, entryPoints) - } - - return make(map[string]map[string]*runtime.TCPRouterInfo) -} - -func (m *Manager) getHTTPRouters(ctx context.Context, entryPoints []string, tls bool) map[string]map[string]*runtime.RouterInfo { - if m.conf != nil { - return m.conf.GetRoutersByEntryPoints(ctx, entryPoints, tls) - } - - return make(map[string]map[string]*runtime.RouterInfo) -} - -// BuildHandlers builds the handlers for the given entrypoints. -func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]*tcp.Router { - entryPointsRouters := m.getTCPRouters(rootCtx, entryPoints) - entryPointsRoutersHTTP := m.getHTTPRouters(rootCtx, entryPoints, true) - - entryPointHandlers := make(map[string]*tcp.Router) - for _, entryPointName := range entryPoints { - entryPointName := entryPointName - - routers := entryPointsRouters[entryPointName] - - ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName)) - - handler, err := m.buildEntryPointHandler(ctx, routers, entryPointsRoutersHTTP[entryPointName], m.httpHandlers[entryPointName], m.httpsHandlers[entryPointName]) - if err != nil { - log.FromContext(ctx).Error(err) - continue - } - entryPointHandlers[entryPointName] = handler - } - return entryPointHandlers -} - -type nameAndConfig struct { - routerName string // just so we have it as additional information when logging - TLSConfig *tls.Config -} - -func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*runtime.TCPRouterInfo, configsHTTP map[string]*runtime.RouterInfo, handlerHTTP, handlerHTTPS http.Handler) (*tcp.Router, error) { - router := &tcp.Router{} - router.HTTPHandler(handlerHTTP) - - defaultTLSConf, err := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, traefiktls.DefaultTLSConfigName) - if err != nil { - log.FromContext(ctx).Errorf("Error during the build of the default TLS configuration: %v", err) - } - - if len(configsHTTP) > 0 { - router.AddRouteHTTPTLS("*", defaultTLSConf) - } - - // Keyed by domain, then by options reference. - tlsOptionsForHostSNI := map[string]map[string]nameAndConfig{} - tlsOptionsForHost := map[string]string{} - for routerHTTPName, routerHTTPConfig := range configsHTTP { - if routerHTTPConfig.TLS == nil { - continue - } - - ctxRouter := log.With(provider.AddInContext(ctx, routerHTTPName), log.Str(log.RouterName, routerHTTPName)) - logger := log.FromContext(ctxRouter) - - tlsOptionsName := traefiktls.DefaultTLSConfigName - if len(routerHTTPConfig.TLS.Options) > 0 && routerHTTPConfig.TLS.Options != traefiktls.DefaultTLSConfigName { - tlsOptionsName = provider.GetQualifiedName(ctxRouter, routerHTTPConfig.TLS.Options) - } - - domains, err := rules.ParseDomains(routerHTTPConfig.Rule) - if err != nil { - routerErr := fmt.Errorf("invalid rule %s, error: %w", routerHTTPConfig.Rule, err) - routerHTTPConfig.AddError(routerErr, true) - logger.Debug(routerErr) - continue - } - - if len(domains) == 0 { - logger.Warnf("No domain found in rule %v, the TLS options applied for this router will depend on the hostSNI of each request", routerHTTPConfig.Rule) - } - - for _, domain := range domains { - tlsConf, err := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, tlsOptionsName) - if err != nil { - routerHTTPConfig.AddError(err, true) - logger.Debug(err) - continue - } - - // domain is already in lower case thanks to the domain parsing - if tlsOptionsForHostSNI[domain] == nil { - tlsOptionsForHostSNI[domain] = make(map[string]nameAndConfig) - } - tlsOptionsForHostSNI[domain][tlsOptionsName] = nameAndConfig{ - routerName: routerHTTPName, - TLSConfig: tlsConf, - } - - if name, ok := tlsOptionsForHost[domain]; ok && name != tlsOptionsName { - // Different tlsOptions on the same domain fallback to default - tlsOptionsForHost[domain] = traefiktls.DefaultTLSConfigName - } else { - tlsOptionsForHost[domain] = tlsOptionsName - } - } - } - - sniCheck := snicheck.New(tlsOptionsForHost, handlerHTTPS) - - router.HTTPSHandler(sniCheck, defaultTLSConf) - - logger := log.FromContext(ctx) - for hostSNI, tlsConfigs := range tlsOptionsForHostSNI { - if len(tlsConfigs) == 1 { - var optionsName string - var config *tls.Config - for k, v := range tlsConfigs { - optionsName = k - config = v.TLSConfig - break - } - - logger.Debugf("Adding route for %s with TLS options %s", hostSNI, optionsName) - - router.AddRouteHTTPTLS(hostSNI, config) - } else { - routers := make([]string, 0, len(tlsConfigs)) - for _, v := range tlsConfigs { - configsHTTP[v.routerName].AddError(fmt.Errorf("found different TLS options for routers on the same host %v, so using the default TLS options instead", hostSNI), false) - routers = append(routers, v.routerName) - } - - logger.Warnf("Found different TLS options for routers on the same host %v, so using the default TLS options instead for these routers: %#v", hostSNI, routers) - - router.AddRouteHTTPTLS(hostSNI, defaultTLSConf) - } - } - - for routerName, routerConfig := range configs { - ctxRouter := log.With(provider.AddInContext(ctx, routerName), log.Str(log.RouterName, routerName)) - logger := log.FromContext(ctxRouter) - - if routerConfig.Service == "" { - err := errors.New("the service is missing on the router") - routerConfig.AddError(err, true) - logger.Error(err) - continue - } - - if routerConfig.Rule == "" { - err := errors.New("router has no rule") - routerConfig.AddError(err, true) - logger.Error(err) - continue - } - - handler, err := m.buildTCPHandler(ctxRouter, routerConfig) - if err != nil { - routerConfig.AddError(err, true) - logger.Error(err) - continue - } - - domains, err := rules.ParseHostSNI(routerConfig.Rule) - if err != nil { - routerErr := fmt.Errorf("unknown rule %s", routerConfig.Rule) - routerConfig.AddError(routerErr, true) - logger.Error(routerErr) - continue - } - - for _, domain := range domains { - logger.Debugf("Adding route %s on TCP", domain) - switch { - case routerConfig.TLS != nil: - if !rules.IsASCII(domain) { - asciiError := fmt.Errorf("invalid domain name value %q, non-ASCII characters are not allowed", domain) - routerConfig.AddError(asciiError, true) - logger.Debug(asciiError) - continue - } - - if routerConfig.TLS.Passthrough { - router.AddRoute(domain, handler) - continue - } - - tlsOptionsName := routerConfig.TLS.Options - - if len(tlsOptionsName) == 0 { - tlsOptionsName = traefiktls.DefaultTLSConfigName - } - - if tlsOptionsName != traefiktls.DefaultTLSConfigName { - tlsOptionsName = provider.GetQualifiedName(ctxRouter, tlsOptionsName) - } - - tlsConf, err := m.tlsManager.Get(traefiktls.DefaultTLSStoreName, tlsOptionsName) - if err != nil { - routerConfig.AddError(err, true) - logger.Debug(err) - continue - } - - router.AddRouteTLS(domain, handler, tlsConf) - case domain == "*": - router.AddCatchAllNoTLS(handler) - default: - logger.Warn("TCP Router ignored, cannot specify a Host rule without TLS") - } - } - } - - return router, nil -} - -func (m *Manager) buildTCPHandler(ctx context.Context, router *runtime.TCPRouterInfo) (tcp.Handler, error) { - var qualifiedNames []string - for _, name := range router.Middlewares { - qualifiedNames = append(qualifiedNames, provider.GetQualifiedName(ctx, name)) - } - router.Middlewares = qualifiedNames - - if router.Service == "" { - return nil, errors.New("the service is missing on the router") - } - - sHandler, err := m.serviceManager.BuildTCP(ctx, router.Service) +// NewRouter returns a new TCP router. +func NewRouter() (*Router, error) { + muxTCP, err := tcpmuxer.NewMuxer() if err != nil { return nil, err } - mHandler := m.middlewaresBuilder.BuildChain(ctx, router.Middlewares) + muxTCPTLS, err := tcpmuxer.NewMuxer() + if err != nil { + return nil, err + } - return tcp.NewChain().Extend(*mHandler).Then(sHandler) + muxHTTPS, err := tcpmuxer.NewMuxer() + if err != nil { + return nil, err + } + + return &Router{ + muxerTCP: *muxTCP, + muxerTCPTLS: *muxTCPTLS, + muxerHTTPS: *muxHTTPS, + }, nil } + +// GetTLSGetClientInfo is called after a ClientHello is received from a client. +func (r *Router) GetTLSGetClientInfo() func(info *tls.ClientHelloInfo) (*tls.Config, error) { + return func(info *tls.ClientHelloInfo) (*tls.Config, error) { + if tlsConfig, ok := r.hostHTTPTLSConfig[info.ServerName]; ok { + return tlsConfig, nil + } + + return r.httpsTLSConfig, nil + } +} + +// ServeTCP forwards the connection to the right TCP/HTTP handler. +func (r *Router) ServeTCP(conn tcp.WriteCloser) { + // Handling Non-TLS TCP connection early if there is neither HTTP(S) nor TLS + // routers on the entryPoint, and if there is at least one non-TLS TCP router. + // In the case of a non-TLS TCP client (that does not "send" first), we would + // block forever on clientHelloServerName, which is why we want to detect and + // handle that case first and foremost. + if r.muxerTCP.HasRoutes() && !r.muxerTCPTLS.HasRoutes() && !r.muxerHTTPS.HasRoutes() { + connData, err := tcpmuxer.NewConnData("", conn) + if err != nil { + log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err) + conn.Close() + return + } + + handler := r.muxerTCP.Match(connData) + // If there is a handler matching the connection metadata, + // we let it handle the connection. + if handler != nil { + handler.ServeTCP(conn) + return + } + // Otherwise, we keep going because: + // 1) we could be in the case where we have HTTP routers. + // 2) if it is an HTTPS request, even though we do not have any TLS routers, + // we still need to reply with a 404. + } + + // FIXME -- Check if ProxyProtocol changes the first bytes of the request + br := bufio.NewReader(conn) + serverName, tls, peeked, err := clientHelloServerName(br) + if err != nil { + conn.Close() + return + } + + // Remove read/write deadline and delegate this to underlying tcp server (for now only handled by HTTP Server) + err = conn.SetReadDeadline(time.Time{}) + if err != nil { + log.WithoutContext().Errorf("Error while setting read deadline: %v", err) + } + + err = conn.SetWriteDeadline(time.Time{}) + if err != nil { + log.WithoutContext().Errorf("Error while setting write deadline: %v", err) + } + + connData, err := tcpmuxer.NewConnData(serverName, conn) + if err != nil { + log.WithoutContext().Errorf("Error while reading TCP connection data: %v", err) + conn.Close() + return + } + + if !tls { + handler := r.muxerTCP.Match(connData) + switch { + case handler != nil: + handler.ServeTCP(r.GetConn(conn, peeked)) + case r.httpForwarder != nil: + r.httpForwarder.ServeTCP(r.GetConn(conn, peeked)) + default: + conn.Close() + } + return + } + + handler := r.muxerTCPTLS.Match(connData) + if handler != nil { + handler.ServeTCP(r.GetConn(conn, peeked)) + return + } + + // for real, the handler returned here is (almost) always the same: + // it is the httpsForwarder that is used for all HTTPS connections that match + // (which is also incidentally the same used in the last block below for 404s). + // The added value from doing Match, is to find and use the specific TLS config + // requested for the given HostSNI. + handler = r.muxerHTTPS.Match(connData) + if handler != nil { + handler.ServeTCP(r.GetConn(conn, peeked)) + return + } + + // needed to handle 404s for HTTPS, as well as all non-Host (e.g. PathPrefix) matches. + if r.httpsForwarder != nil { + r.httpsForwarder.ServeTCP(r.GetConn(conn, peeked)) + return + } + + conn.Close() +} + +// AddRoute defines a handler for the given rule. +func (r *Router) AddRoute(rule string, priority int, target tcp.Handler) error { + return r.muxerTCP.AddRoute(rule, priority, target) +} + +// AddRouteTLS defines a handler for a given rule and sets the matching tlsConfig. +func (r *Router) AddRouteTLS(rule string, priority int, target tcp.Handler, config *tls.Config) error { + // TLS PassThrough + if config == nil { + return r.muxerTCPTLS.AddRoute(rule, priority, target) + } + + return r.muxerTCPTLS.AddRoute(rule, priority, &tcp.TLSHandler{ + Next: target, + Config: config, + }) +} + +// AddHTTPTLSConfig defines a handler for a given sniHost and sets the matching tlsConfig. +func (r *Router) AddHTTPTLSConfig(sniHost string, config *tls.Config) { + if r.hostHTTPTLSConfig == nil { + r.hostHTTPTLSConfig = map[string]*tls.Config{} + } + + r.hostHTTPTLSConfig[sniHost] = config +} + +// GetConn creates a connection proxy with a peeked string. +func (r *Router) GetConn(conn tcp.WriteCloser, peeked string) tcp.WriteCloser { + // FIXME should it really be on Router ? + conn = &Conn{ + Peeked: []byte(peeked), + WriteCloser: conn, + } + + return conn +} + +// GetHTTPHandler gets the attached http handler. +func (r *Router) GetHTTPHandler() http.Handler { + return r.httpHandler +} + +// GetHTTPSHandler gets the attached https handler. +func (r *Router) GetHTTPSHandler() http.Handler { + return r.httpsHandler +} + +// SetHTTPForwarder sets the tcp handler that will forward the connections to an http handler. +func (r *Router) SetHTTPForwarder(handler tcp.Handler) { + r.httpForwarder = handler +} + +// SetHTTPSForwarder sets the tcp handler that will forward the TLS connections to an http handler. +func (r *Router) SetHTTPSForwarder(handler tcp.Handler) { + for sniHost, tlsConf := range r.hostHTTPTLSConfig { + // muxerHTTPS only contains single HostSNI rules (and no other kind of rules), + // so there's no need for specifying a priority for them. + err := r.muxerHTTPS.AddRoute("HostSNI(`"+sniHost+"`)", 0, &tcp.TLSHandler{ + Next: handler, + Config: tlsConf, + }) + if err != nil { + log.WithoutContext().Errorf("Error while adding route for host: %w", err) + } + } + + r.httpsForwarder = &tcp.TLSHandler{ + Next: handler, + Config: r.httpsTLSConfig, + } +} + +// SetHTTPHandler attaches http handlers on the router. +func (r *Router) SetHTTPHandler(handler http.Handler) { + r.httpHandler = handler +} + +// SetHTTPSHandler attaches https handlers on the router. +func (r *Router) SetHTTPSHandler(handler http.Handler, config *tls.Config) { + r.httpsHandler = handler + r.httpsTLSConfig = config +} + +// Conn is a connection proxy that handles Peeked bytes. +type Conn struct { + // Peeked are the bytes that have been read from Conn for the + // purposes of route matching, but have not yet been consumed + // by Read calls. It set to nil by Read when fully consumed. + Peeked []byte + + // Conn is the underlying connection. + // It can be type asserted against *net.TCPConn or other types + // as needed. It should not be read from directly unless + // Peeked is nil. + tcp.WriteCloser +} + +// Read reads bytes from the connection (using the buffer prior to actually reading). +func (c *Conn) Read(p []byte) (n int, err error) { + if len(c.Peeked) > 0 { + n = copy(p, c.Peeked) + c.Peeked = c.Peeked[n:] + if len(c.Peeked) == 0 { + c.Peeked = nil + } + return n, nil + } + return c.WriteCloser.Read(p) +} + +// clientHelloServerName returns the SNI server name inside the TLS ClientHello, +// without consuming any bytes from br. +// On any error, the empty string is returned. +func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) { + hdr, err := br.Peek(1) + if err != nil { + var opErr *net.OpError + if !errors.Is(err, io.EOF) && (!errors.As(err, &opErr) || opErr.Timeout()) { + log.WithoutContext().Errorf("Error while Peeking first byte: %s", err) + } + + return "", false, "", err + } + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + const recordTypeSSLv2 = 0x80 + const recordTypeHandshake = 0x16 + if hdr[0] != recordTypeHandshake { + if hdr[0] == recordTypeSSLv2 { + // we consider SSLv2 as TLS and it will be refused by real TLS handshake. + return "", true, getPeeked(br), nil + } + return "", false, getPeeked(br), nil // Not TLS. + } + + const recordHeaderLen = 5 + hdr, err = br.Peek(recordHeaderLen) + if err != nil { + log.Errorf("Error while Peeking hello: %s", err) + return "", false, getPeeked(br), nil + } + + recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] + + if recordHeaderLen+recLen > defaultBufSize { + br = bufio.NewReaderSize(br, recordHeaderLen+recLen) + } + + helloBytes, err := br.Peek(recordHeaderLen + recLen) + if err != nil { + log.Errorf("Error while Hello: %s", err) + return "", true, getPeeked(br), nil + } + + sni := "" + server := tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ + GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + sni = hello.ServerName + return nil, nil + }, + }) + _ = server.Handshake() + + return sni, true, getPeeked(br), nil +} + +func getPeeked(br *bufio.Reader) string { + peeked, err := br.Peek(br.Buffered()) + if err != nil { + log.Errorf("Could not get anything: %s", err) + return "" + } + return string(peeked) +} + +// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// and crashes otherwise. +type sniSniffConn struct { + r io.Reader + net.Conn // nil; crash on any unexpected use +} + +// Read reads from the underlying reader. +func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } + +// Write crashes all the time. +func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } diff --git a/pkg/server/routerfactory.go b/pkg/server/routerfactory.go index 58f4b0244..2226b3db3 100644 --- a/pkg/server/routerfactory.go +++ b/pkg/server/routerfactory.go @@ -8,16 +8,15 @@ import ( "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/metrics" "github.com/traefik/traefik/v2/pkg/server/middleware" - middlewaretcp "github.com/traefik/traefik/v2/pkg/server/middleware/tcp" + tcpmiddleware "github.com/traefik/traefik/v2/pkg/server/middleware/tcp" "github.com/traefik/traefik/v2/pkg/server/router" - routertcp "github.com/traefik/traefik/v2/pkg/server/router/tcp" - routerudp "github.com/traefik/traefik/v2/pkg/server/router/udp" + tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp" + udprouter "github.com/traefik/traefik/v2/pkg/server/router/udp" "github.com/traefik/traefik/v2/pkg/server/service" "github.com/traefik/traefik/v2/pkg/server/service/tcp" "github.com/traefik/traefik/v2/pkg/server/service/udp" - tcpCore "github.com/traefik/traefik/v2/pkg/tcp" "github.com/traefik/traefik/v2/pkg/tls" - udpCore "github.com/traefik/traefik/v2/pkg/udp" + udptypes "github.com/traefik/traefik/v2/pkg/udp" ) // RouterFactory the factory of TCP/UDP routers. @@ -64,7 +63,7 @@ func NewRouterFactory(staticConfiguration static.Configuration, managerFactory * } // CreateRouters creates new TCPRouters and UDPRouters. -func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string]*tcpCore.Router, map[string]udpCore.Handler) { +func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string]*tcprouter.Router, map[string]udptypes.Handler) { ctx := context.Background() // HTTP @@ -82,14 +81,14 @@ func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string // TCP svcTCPManager := tcp.NewManager(rtConf) - middlewaresTCPBuilder := middlewaretcp.NewBuilder(rtConf.TCPMiddlewares) + middlewaresTCPBuilder := tcpmiddleware.NewBuilder(rtConf.TCPMiddlewares) - rtTCPManager := routertcp.NewManager(rtConf, svcTCPManager, middlewaresTCPBuilder, handlersNonTLS, handlersTLS, f.tlsManager) + rtTCPManager := tcprouter.NewManager(rtConf, svcTCPManager, middlewaresTCPBuilder, handlersNonTLS, handlersTLS, f.tlsManager) routersTCP := rtTCPManager.BuildHandlers(ctx, f.entryPointsTCP) // UDP svcUDPManager := udp.NewManager(rtConf) - rtUDPManager := routerudp.NewManager(rtConf, svcUDPManager) + rtUDPManager := udprouter.NewManager(rtConf, svcUDPManager) routersUDP := rtUDPManager.BuildHandlers(ctx, f.entryPointsUDP) rtConf.PopulateUsedBy() diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index 42fb868cf..41bd85fc1 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -22,6 +22,7 @@ import ( "github.com/traefik/traefik/v2/pkg/middlewares/requestdecorator" "github.com/traefik/traefik/v2/pkg/safe" "github.com/traefik/traefik/v2/pkg/server/router" + tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp" "github.com/traefik/traefik/v2/pkg/tcp" "github.com/traefik/traefik/v2/pkg/types" "golang.org/x/net/http2" @@ -114,7 +115,7 @@ func (eps TCPEntryPoints) Stop() { } // Switch the TCP routers. -func (eps TCPEntryPoints) Switch(routersTCP map[string]*tcp.Router) { +func (eps TCPEntryPoints) Switch(routersTCP map[string]*tcprouter.Router) { for entryPointName, rt := range routersTCP { eps[entryPointName].SwitchRouter(rt) } @@ -141,7 +142,7 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hos return nil, fmt.Errorf("error preparing server: %w", err) } - rt := &tcp.Router{} + rt := &tcprouter.Router{} reqDecorator := requestdecorator.New(hostResolverConfig) @@ -150,7 +151,7 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hos return nil, fmt.Errorf("error preparing http server: %w", err) } - rt.HTTPForwarder(httpServer.Forwarder) + rt.SetHTTPForwarder(httpServer.Forwarder) httpsServer, err := createHTTPServer(ctx, listener, configuration, false, reqDecorator) if err != nil { @@ -162,7 +163,7 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hos return nil, fmt.Errorf("error preparing http3 server: %w", err) } - rt.HTTPSForwarder(httpsServer.Forwarder) + rt.SetHTTPSForwarder(httpsServer.Forwarder) tcpSwitcher := &tcp.HandlerSwitcher{} tcpSwitcher.Switch(rt) @@ -181,7 +182,7 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hos // Start starts the TCP server. func (e *TCPEntryPoint) Start(ctx context.Context) { logger := log.FromContext(ctx) - logger.Debugf("Start TCP Server") + logger.Debug("Starting TCP Server") if e.http3Server != nil { go func() { _ = e.http3Server.Start() }() @@ -301,8 +302,8 @@ func (e *TCPEntryPoint) Shutdown(ctx context.Context) { } // SwitchRouter switches the TCP router handler. -func (e *TCPEntryPoint) SwitchRouter(rt *tcp.Router) { - rt.HTTPForwarder(e.httpServer.Forwarder) +func (e *TCPEntryPoint) SwitchRouter(rt *tcprouter.Router) { + rt.SetHTTPForwarder(e.httpServer.Forwarder) httpHandler := rt.GetHTTPHandler() if httpHandler == nil { @@ -311,7 +312,7 @@ func (e *TCPEntryPoint) SwitchRouter(rt *tcp.Router) { e.httpServer.Switcher.UpdateHandler(httpHandler) - rt.HTTPSForwarder(e.httpsServer.Forwarder) + rt.SetHTTPSForwarder(e.httpsServer.Forwarder) httpsHandler := rt.GetHTTPSHandler() if httpsHandler == nil { diff --git a/pkg/server/server_entrypoint_tcp_http3.go b/pkg/server/server_entrypoint_tcp_http3.go index 81146ae20..2c40a0f2f 100644 --- a/pkg/server/server_entrypoint_tcp_http3.go +++ b/pkg/server/server_entrypoint_tcp_http3.go @@ -13,7 +13,7 @@ import ( "github.com/lucas-clemente/quic-go/http3" "github.com/traefik/traefik/v2/pkg/config/static" "github.com/traefik/traefik/v2/pkg/log" - "github.com/traefik/traefik/v2/pkg/tcp" + tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp" ) type http3server struct { @@ -77,7 +77,7 @@ func (e *http3server) Start() error { return e.Serve(e.http3conn) } -func (e *http3server) Switch(rt *tcp.Router) { +func (e *http3server) Switch(rt *tcprouter.Router) { e.lock.Lock() defer e.lock.Unlock() diff --git a/pkg/server/server_entrypoint_tcp_http3_test.go b/pkg/server/server_entrypoint_tcp_http3_test.go index b0759ef48..bf863653a 100644 --- a/pkg/server/server_entrypoint_tcp_http3_test.go +++ b/pkg/server/server_entrypoint_tcp_http3_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traefik/traefik/v2/pkg/config/static" - "github.com/traefik/traefik/v2/pkg/tcp" + tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp" traefiktls "github.com/traefik/traefik/v2/pkg/tls" ) @@ -94,11 +94,13 @@ func TestHTTP3AdvertisedPort(t *testing.T) { }, nil) require.NoError(t, err) - router := &tcp.Router{} - router.AddRouteHTTPTLS("*", &tls.Config{ + router, err := tcprouter.NewRouter() + require.NoError(t, err) + + router.AddHTTPTLSConfig("*", &tls.Config{ Certificates: []tls.Certificate{tlsCert}, }) - router.HTTPSHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + router.SetHTTPSHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) }), nil) diff --git a/pkg/server/server_entrypoint_tcp_test.go b/pkg/server/server_entrypoint_tcp_test.go index ab49cd5a9..1477629a7 100644 --- a/pkg/server/server_entrypoint_tcp_test.go +++ b/pkg/server/server_entrypoint_tcp_test.go @@ -15,12 +15,13 @@ import ( "github.com/stretchr/testify/require" ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/config/static" + tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp" "github.com/traefik/traefik/v2/pkg/tcp" ) func TestShutdownHijacked(t *testing.T) { - router := &tcp.Router{} - router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + router := &tcprouter.Router{} + router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { conn, _, err := rw.(http.Hijacker).Hijack() require.NoError(t, err) @@ -33,8 +34,8 @@ func TestShutdownHijacked(t *testing.T) { } func TestShutdownHTTP(t *testing.T) { - router := &tcp.Router{} - router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + router := &tcprouter.Router{} + router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) time.Sleep(time.Second) })) @@ -43,8 +44,10 @@ func TestShutdownHTTP(t *testing.T) { } func TestShutdownTCP(t *testing.T) { - router := &tcp.Router{} - router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) { + router, err := tcprouter.NewRouter() + require.NoError(t, err) + + err = router.AddRoute("HostSNI(`*`)", 0, tcp.HandlerFunc(func(conn tcp.WriteCloser) { for { _, err := http.ReadRequest(bufio.NewReader(conn)) @@ -58,11 +61,12 @@ func TestShutdownTCP(t *testing.T) { require.NoError(t, err) } })) + require.NoError(t, err) testShutdown(t, router) } -func testShutdown(t *testing.T, router *tcp.Router) { +func testShutdown(t *testing.T, router *tcprouter.Router) { t.Helper() epConfig := &static.EntryPointsTransport{} @@ -135,7 +139,7 @@ func testShutdown(t *testing.T, router *tcp.Router) { assert.Equal(t, http.StatusOK, resp.StatusCode) } -func startEntrypoint(entryPoint *TCPEntryPoint, router *tcp.Router) (net.Conn, error) { +func startEntrypoint(entryPoint *TCPEntryPoint, router *tcprouter.Router) (net.Conn, error) { go entryPoint.Start(context.Background()) entryPoint.SwitchRouter(router) @@ -165,8 +169,8 @@ func TestReadTimeoutWithoutFirstByte(t *testing.T) { }, nil) require.NoError(t, err) - router := &tcp.Router{} - router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + router := &tcprouter.Router{} + router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) })) @@ -201,8 +205,8 @@ func TestReadTimeoutWithFirstByte(t *testing.T) { }, nil) require.NoError(t, err) - router := &tcp.Router{} - router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + router := &tcprouter.Router{} + router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusOK) })) diff --git a/pkg/server/service/loadbalancer/mirror/mirror.go b/pkg/server/service/loadbalancer/mirror/mirror.go index bd7142394..f0d12a952 100644 --- a/pkg/server/service/loadbalancer/mirror/mirror.go +++ b/pkg/server/service/loadbalancer/mirror/mirror.go @@ -93,7 +93,7 @@ func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if errors.Is(err, errBodyTooLarge) { req.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body)) m.handler.ServeHTTP(rw, req) - logger.Debugf("no mirroring, request body larger than allowed size") + logger.Debug("no mirroring, request body larger than allowed size") return } diff --git a/pkg/tcp/router.go b/pkg/tcp/router.go deleted file mode 100644 index 6be63ba44..000000000 --- a/pkg/tcp/router.go +++ /dev/null @@ -1,286 +0,0 @@ -package tcp - -import ( - "bufio" - "bytes" - "crypto/tls" - "errors" - "io" - "net" - "net/http" - "strings" - "time" - - "github.com/traefik/traefik/v2/pkg/log" - "github.com/traefik/traefik/v2/pkg/types" -) - -const defaultBufSize = 4096 - -// Router is a TCP router. -type Router struct { - routingTable map[string]Handler - httpForwarder Handler - httpsForwarder Handler - httpHandler http.Handler - httpsHandler http.Handler - httpsTLSConfig *tls.Config // default TLS config - catchAllNoTLS Handler - hostHTTPTLSConfig map[string]*tls.Config // TLS configs keyed by SNI -} - -// GetTLSGetClientInfo is called after a ClientHello is received from a client. -func (r *Router) GetTLSGetClientInfo() func(info *tls.ClientHelloInfo) (*tls.Config, error) { - return func(info *tls.ClientHelloInfo) (*tls.Config, error) { - if tlsConfig, ok := r.hostHTTPTLSConfig[info.ServerName]; ok { - return tlsConfig, nil - } - return r.httpsTLSConfig, nil - } -} - -// ServeTCP forwards the connection to the right TCP/HTTP handler. -func (r *Router) ServeTCP(conn WriteCloser) { - // FIXME -- Check if ProxyProtocol changes the first bytes of the request - - if r.catchAllNoTLS != nil && len(r.routingTable) == 0 { - r.catchAllNoTLS.ServeTCP(conn) - return - } - - br := bufio.NewReader(conn) - serverName, tls, peeked, err := clientHelloServerName(br) - if err != nil { - conn.Close() - return - } - - // Remove read/write deadline and delegate this to underlying tcp server (for now only handled by HTTP Server) - err = conn.SetReadDeadline(time.Time{}) - if err != nil { - log.WithoutContext().Errorf("Error while setting read deadline: %v", err) - } - - err = conn.SetWriteDeadline(time.Time{}) - if err != nil { - log.WithoutContext().Errorf("Error while setting write deadline: %v", err) - } - - if !tls { - switch { - case r.catchAllNoTLS != nil: - r.catchAllNoTLS.ServeTCP(r.GetConn(conn, peeked)) - case r.httpForwarder != nil: - r.httpForwarder.ServeTCP(r.GetConn(conn, peeked)) - default: - conn.Close() - } - return - } - - // FIXME Optimize and test the routing table before helloServerName - serverName = types.CanonicalDomain(serverName) - if r.routingTable != nil && serverName != "" { - if target, ok := r.routingTable[serverName]; ok { - target.ServeTCP(r.GetConn(conn, peeked)) - return - } - } - - // FIXME Needs tests - if target, ok := r.routingTable["*"]; ok { - target.ServeTCP(r.GetConn(conn, peeked)) - return - } - - if r.httpsForwarder != nil { - r.httpsForwarder.ServeTCP(r.GetConn(conn, peeked)) - } else { - conn.Close() - } -} - -// AddRoute defines a handler for a given sniHost (* is the only valid option). -func (r *Router) AddRoute(sniHost string, target Handler) { - if r.routingTable == nil { - r.routingTable = map[string]Handler{} - } - r.routingTable[strings.ToLower(sniHost)] = target -} - -// AddRouteTLS defines a handler for a given sniHost and sets the matching tlsConfig. -func (r *Router) AddRouteTLS(sniHost string, target Handler, config *tls.Config) { - r.AddRoute(sniHost, &TLSHandler{ - Next: target, - Config: config, - }) -} - -// AddRouteHTTPTLS defines a handler for a given sniHost and sets the matching tlsConfig. -func (r *Router) AddRouteHTTPTLS(sniHost string, config *tls.Config) { - if r.hostHTTPTLSConfig == nil { - r.hostHTTPTLSConfig = map[string]*tls.Config{} - } - r.hostHTTPTLSConfig[sniHost] = config -} - -// AddCatchAllNoTLS defines the fallback tcp handler. -func (r *Router) AddCatchAllNoTLS(handler Handler) { - r.catchAllNoTLS = handler -} - -// GetConn creates a connection proxy with a peeked string. -func (r *Router) GetConn(conn WriteCloser, peeked string) WriteCloser { - // FIXME should it really be on Router ? - conn = &Conn{ - Peeked: []byte(peeked), - WriteCloser: conn, - } - return conn -} - -// GetHTTPHandler gets the attached http handler. -func (r *Router) GetHTTPHandler() http.Handler { - return r.httpHandler -} - -// GetHTTPSHandler gets the attached https handler. -func (r *Router) GetHTTPSHandler() http.Handler { - return r.httpsHandler -} - -// HTTPForwarder sets the tcp handler that will forward the connections to an http handler. -func (r *Router) HTTPForwarder(handler Handler) { - r.httpForwarder = handler -} - -// HTTPSForwarder sets the tcp handler that will forward the TLS connections to an http handler. -func (r *Router) HTTPSForwarder(handler Handler) { - for sniHost, tlsConf := range r.hostHTTPTLSConfig { - r.AddRouteTLS(sniHost, handler, tlsConf) - } - - r.httpsForwarder = &TLSHandler{ - Next: handler, - Config: r.httpsTLSConfig, - } -} - -// HTTPHandler attaches http handlers on the router. -func (r *Router) HTTPHandler(handler http.Handler) { - r.httpHandler = handler -} - -// HTTPSHandler attaches https handlers on the router. -func (r *Router) HTTPSHandler(handler http.Handler, config *tls.Config) { - r.httpsHandler = handler - r.httpsTLSConfig = config -} - -// Conn is a connection proxy that handles Peeked bytes. -type Conn struct { - // Peeked are the bytes that have been read from Conn for the - // purposes of route matching, but have not yet been consumed - // by Read calls. It set to nil by Read when fully consumed. - Peeked []byte - - // Conn is the underlying connection. - // It can be type asserted against *net.TCPConn or other types - // as needed. It should not be read from directly unless - // Peeked is nil. - WriteCloser -} - -// Read reads bytes from the connection (using the buffer prior to actually reading). -func (c *Conn) Read(p []byte) (n int, err error) { - if len(c.Peeked) > 0 { - n = copy(p, c.Peeked) - c.Peeked = c.Peeked[n:] - if len(c.Peeked) == 0 { - c.Peeked = nil - } - return n, nil - } - return c.WriteCloser.Read(p) -} - -// clientHelloServerName returns the SNI server name inside the TLS ClientHello, -// without consuming any bytes from br. -// On any error, the empty string is returned. -func clientHelloServerName(br *bufio.Reader) (string, bool, string, error) { - hdr, err := br.Peek(1) - if err != nil { - var opErr *net.OpError - if !errors.Is(err, io.EOF) && (!errors.As(err, &opErr) || opErr.Timeout()) { - log.WithoutContext().Debugf("Error while Peeking first byte: %s", err) - } - - return "", false, "", err - } - - // No valid TLS record has a type of 0x80, however SSLv2 handshakes - // start with a uint16 length where the MSB is set and the first record - // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests - // an SSLv2 client. - const recordTypeSSLv2 = 0x80 - const recordTypeHandshake = 0x16 - if hdr[0] != recordTypeHandshake { - if hdr[0] == recordTypeSSLv2 { - // we consider SSLv2 as TLS and it will be refuse by real TLS handshake. - return "", true, getPeeked(br), nil - } - return "", false, getPeeked(br), nil // Not TLS. - } - - const recordHeaderLen = 5 - hdr, err = br.Peek(recordHeaderLen) - if err != nil { - log.Errorf("Error while Peeking hello: %s", err) - return "", false, getPeeked(br), nil - } - - recLen := int(hdr[3])<<8 | int(hdr[4]) // ignoring version in hdr[1:3] - - if recordHeaderLen+recLen > defaultBufSize { - br = bufio.NewReaderSize(br, recordHeaderLen+recLen) - } - - helloBytes, err := br.Peek(recordHeaderLen + recLen) - if err != nil { - log.Errorf("Error while Hello: %s", err) - return "", true, getPeeked(br), nil - } - - sni := "" - server := tls.Server(sniSniffConn{r: bytes.NewReader(helloBytes)}, &tls.Config{ - GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { - sni = hello.ServerName - return nil, nil - }, - }) - _ = server.Handshake() - - return sni, true, getPeeked(br), nil -} - -func getPeeked(br *bufio.Reader) string { - peeked, err := br.Peek(br.Buffered()) - if err != nil { - log.Errorf("Could not get anything: %s", err) - return "" - } - return string(peeked) -} - -// sniSniffConn is a net.Conn that reads from r, fails on Writes, -// and crashes otherwise. -type sniSniffConn struct { - r io.Reader - net.Conn // nil; crash on any unexpected use -} - -// Read reads from the underlying reader. -func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } - -// Write crashes all the time. -func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF }