Merge branch 'v1.7' into master
This commit is contained in:
commit
d8f69700e6
10 changed files with 407 additions and 38 deletions
2
Gopkg.lock
generated
2
Gopkg.lock
generated
|
@ -1266,7 +1266,7 @@
|
||||||
"roundrobin",
|
"roundrobin",
|
||||||
"utils"
|
"utils"
|
||||||
]
|
]
|
||||||
revision = "f0cbb9d6b797d92d168b95b5c443a31dfa67ccd0"
|
revision = "a3ed5f65204f4ffccbb56d58cec466cdb7ab730b"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
name = "github.com/vulcand/predicate"
|
name = "github.com/vulcand/predicate"
|
||||||
|
|
|
@ -33,7 +33,7 @@ func (p *Provider) buildConfigurationV2(containersInspected []dockerData) *types
|
||||||
"getDomain": label.GetFuncString(label.TraefikDomain, p.Domain),
|
"getDomain": label.GetFuncString(label.TraefikDomain, p.Domain),
|
||||||
|
|
||||||
// Backend functions
|
// Backend functions
|
||||||
"getIPAddress": p.getIPAddress,
|
"getIPAddress": p.getDeprecatedIPAddress, // TODO: Should we expose getIPPort instead?
|
||||||
"getServers": p.getServers,
|
"getServers": p.getServers,
|
||||||
"getMaxConn": label.GetMaxConn,
|
"getMaxConn": label.GetMaxConn,
|
||||||
"getHealthCheck": label.GetHealthCheck,
|
"getHealthCheck": label.GetHealthCheck,
|
||||||
|
@ -235,17 +235,6 @@ func (p Provider) getIPAddress(container dockerData) string {
|
||||||
return p.getIPAddress(parseContainer(containerInspected))
|
return p.getIPAddress(parseContainer(containerInspected))
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.UseBindPortIP {
|
|
||||||
port := getPortV1(container)
|
|
||||||
for netPort, portBindings := range container.NetworkSettings.Ports {
|
|
||||||
if string(netPort) == port+"/TCP" || string(netPort) == port+"/UDP" {
|
|
||||||
for _, p := range portBindings {
|
|
||||||
return p.HostIP
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, network := range container.NetworkSettings.Networks {
|
for _, network := range container.NetworkSettings.Networks {
|
||||||
return network.Addr
|
return network.Addr
|
||||||
}
|
}
|
||||||
|
@ -254,6 +243,16 @@ func (p Provider) getIPAddress(container dockerData) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deprecated: Please use getIPPort instead
|
||||||
|
func (p *Provider) getDeprecatedIPAddress(container dockerData) string {
|
||||||
|
ip, _, err := p.getIPPort(container)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn(err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
// Escape beginning slash "/", convert all others to dash "-", and convert underscores "_" to dash "-"
|
// Escape beginning slash "/", convert all others to dash "-", and convert underscores "_" to dash "-"
|
||||||
func getSubDomain(name string) string {
|
func getSubDomain(name string) string {
|
||||||
return strings.Replace(strings.Replace(strings.TrimPrefix(name, "/"), "/", "-", -1), "_", "-", -1)
|
return strings.Replace(strings.Replace(strings.TrimPrefix(name, "/"), "/", "-", -1), "_", "-", -1)
|
||||||
|
@ -322,13 +321,53 @@ func getPort(container dockerData) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Provider) getPortBinding(container dockerData) (*nat.PortBinding, error) {
|
||||||
|
port := getPort(container)
|
||||||
|
for netPort, portBindings := range container.NetworkSettings.Ports {
|
||||||
|
if strings.EqualFold(string(netPort), port+"/TCP") || strings.EqualFold(string(netPort), port+"/UDP") {
|
||||||
|
for _, p := range portBindings {
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unable to find the external IP:Port for the container %q", container.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) getIPPort(container dockerData) (string, string, error) {
|
||||||
|
var ip, port string
|
||||||
|
|
||||||
|
if p.UseBindPortIP {
|
||||||
|
portBinding, err := p.getPortBinding(container)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("unable to find a binding for the container %q: ignoring server", container.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if portBinding.HostIP == "0.0.0.0" {
|
||||||
|
return "", "", fmt.Errorf("cannot determine the IP address (got 0.0.0.0) for the container %q: ignoring server", container.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip = portBinding.HostIP
|
||||||
|
port = portBinding.HostPort
|
||||||
|
|
||||||
|
} else {
|
||||||
|
ip = p.getIPAddress(container)
|
||||||
|
port = getPort(container)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ip) == 0 {
|
||||||
|
return "", "", fmt.Errorf("unable to find the IP address for the container %q: the server is ignored", container.Name)
|
||||||
|
}
|
||||||
|
return ip, port, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Provider) getServers(containers []dockerData) map[string]types.Server {
|
func (p *Provider) getServers(containers []dockerData) map[string]types.Server {
|
||||||
var servers map[string]types.Server
|
var servers map[string]types.Server
|
||||||
|
|
||||||
for _, container := range containers {
|
for _, container := range containers {
|
||||||
ip := p.getIPAddress(container)
|
ip, port, err := p.getIPPort(container)
|
||||||
if len(ip) == 0 {
|
if err != nil {
|
||||||
log.Warnf("Unable to find the IP address for the container %q: the server is ignored.", container.Name)
|
log.Warn(err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -337,7 +376,6 @@ func (p *Provider) getServers(containers []dockerData) map[string]types.Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol := label.GetStringValue(container.SegmentLabels, label.TraefikProtocol, label.DefaultProtocol)
|
protocol := label.GetStringValue(container.SegmentLabels, label.TraefikProtocol, label.DefaultProtocol)
|
||||||
port := getPort(container)
|
|
||||||
|
|
||||||
serverURL := fmt.Sprintf("%s://%s", protocol, net.JoinHostPort(ip, port))
|
serverURL := fmt.Sprintf("%s://%s", protocol, net.JoinHostPort(ip, port))
|
||||||
|
|
||||||
|
|
|
@ -1287,12 +1287,173 @@ func TestDockerGetIPAddress(t *testing.T) {
|
||||||
Network: "webnet",
|
Network: "webnet",
|
||||||
}
|
}
|
||||||
|
|
||||||
actual := provider.getIPAddress(dData)
|
actual := provider.getDeprecatedIPAddress(dData)
|
||||||
assert.Equal(t, test.expected, actual)
|
assert.Equal(t, test.expected, actual)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDockerGetIPPort(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
container docker.ContainerJSON
|
||||||
|
ip, port string
|
||||||
|
expectsError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "label traefik.port not set, binding with ip:port should create a route to the bound ip:port",
|
||||||
|
container: containerJSON(
|
||||||
|
ports(nat.PortMap{
|
||||||
|
"80/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostIP: "1.2.3.4",
|
||||||
|
HostPort: "8081",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
ip: "1.2.3.4",
|
||||||
|
port: "8081",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "label traefik.port set, multiple bindings on different ports, uses the label to select the correct (first) binding",
|
||||||
|
container: containerJSON(
|
||||||
|
labels(map[string]string{
|
||||||
|
label.TraefikPort: "80",
|
||||||
|
}),
|
||||||
|
ports(nat.PortMap{
|
||||||
|
"80/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostIP: "1.2.3.4",
|
||||||
|
HostPort: "8081",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"443/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostIP: "5.6.7.8",
|
||||||
|
HostPort: "8082",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
ip: "1.2.3.4",
|
||||||
|
port: "8081",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "label traefik.port set, multiple bindings on different ports, uses the label to select the correct (second) binding",
|
||||||
|
container: containerJSON(
|
||||||
|
labels(map[string]string{
|
||||||
|
label.TraefikPort: "443",
|
||||||
|
}),
|
||||||
|
ports(nat.PortMap{
|
||||||
|
"80/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostIP: "1.2.3.4",
|
||||||
|
HostPort: "8081",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"443/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostIP: "5.6.7.8",
|
||||||
|
HostPort: "8082",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
ip: "5.6.7.8",
|
||||||
|
port: "8082",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "label traefik.port set, single binding with ip:port for the label, creates the route",
|
||||||
|
container: containerJSON(
|
||||||
|
labels(map[string]string{
|
||||||
|
label.TraefikPort: "443",
|
||||||
|
}),
|
||||||
|
ports(nat.PortMap{
|
||||||
|
"443/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostIP: "5.6.7.8",
|
||||||
|
HostPort: "8082",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
ip: "5.6.7.8",
|
||||||
|
port: "8082",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "label traefik.port not set, single binding with port only, server ignored",
|
||||||
|
container: containerJSON(
|
||||||
|
ports(nat.PortMap{
|
||||||
|
"80/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostPort: "8082",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
expectsError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "label traefik.port not set, no binding, server ignored",
|
||||||
|
container: containerJSON(
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
expectsError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "label traefik.port set, no binding on the corresponding port, server ignored",
|
||||||
|
container: containerJSON(
|
||||||
|
labels(map[string]string{
|
||||||
|
label.TraefikPort: "80",
|
||||||
|
}),
|
||||||
|
ports(nat.PortMap{
|
||||||
|
"443/tcp": []nat.PortBinding{
|
||||||
|
{
|
||||||
|
HostIP: "5.6.7.8",
|
||||||
|
HostPort: "8082",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
expectsError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "label traefik.port set, no binding, server ignored",
|
||||||
|
container: containerJSON(
|
||||||
|
labels(map[string]string{
|
||||||
|
label.TraefikPort: "80",
|
||||||
|
}),
|
||||||
|
withNetwork("testnet", ipv4("10.11.12.13"))),
|
||||||
|
expectsError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dData := parseContainer(test.container)
|
||||||
|
segmentProperties := label.ExtractTraefikLabels(dData.Labels)
|
||||||
|
dData.SegmentLabels = segmentProperties[""]
|
||||||
|
|
||||||
|
provider := &Provider{
|
||||||
|
Network: "webnet",
|
||||||
|
UseBindPortIP: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
actualIP, actualPort, actualError := provider.getIPPort(dData)
|
||||||
|
if test.expectsError {
|
||||||
|
require.Error(t, actualError)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, actualError)
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.ip, actualIP)
|
||||||
|
assert.Equal(t, test.port, actualPort)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDockerGetPort(t *testing.T) {
|
func TestDockerGetPort(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
container docker.ContainerJSON
|
container docker.ContainerJSON
|
||||||
|
|
|
@ -933,7 +933,7 @@ func TestSwarmGetIPAddress(t *testing.T) {
|
||||||
segmentProperties := label.ExtractTraefikLabels(dData.Labels)
|
segmentProperties := label.ExtractTraefikLabels(dData.Labels)
|
||||||
dData.SegmentLabels = segmentProperties[""]
|
dData.SegmentLabels = segmentProperties[""]
|
||||||
|
|
||||||
actual := provider.getIPAddress(dData)
|
actual := provider.getDeprecatedIPAddress(dData)
|
||||||
assert.Equal(t, test.expected, actual)
|
assert.Equal(t, test.expected, actual)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
package docker
|
package docker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/BurntSushi/ty/fun"
|
"github.com/BurntSushi/ty/fun"
|
||||||
|
@ -19,7 +21,7 @@ func (p *Provider) buildConfigurationV1(containersInspected []dockerData) *types
|
||||||
"isBackendLBSwarm": isBackendLBSwarm,
|
"isBackendLBSwarm": isBackendLBSwarm,
|
||||||
|
|
||||||
// Backend functions
|
// Backend functions
|
||||||
"getIPAddress": p.getIPAddress,
|
"getIPAddress": p.getIPAddressV1,
|
||||||
"getPort": getPortV1,
|
"getPort": getPortV1,
|
||||||
"getWeight": getFuncIntLabelV1(label.TraefikWeight, label.DefaultWeight),
|
"getWeight": getFuncIntLabelV1(label.TraefikWeight, label.DefaultWeight),
|
||||||
"getProtocol": getFuncStringLabelV1(label.TraefikProtocol, label.DefaultProtocol),
|
"getProtocol": getFuncStringLabelV1(label.TraefikProtocol, label.DefaultProtocol),
|
||||||
|
@ -202,3 +204,60 @@ func (p Provider) containerFilterV1(container dockerData) bool {
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p Provider) getIPAddressV1(container dockerData) string {
|
||||||
|
if value := label.GetStringValue(container.Labels, labelDockerNetwork, p.Network); value != "" {
|
||||||
|
networkSettings := container.NetworkSettings
|
||||||
|
if networkSettings.Networks != nil {
|
||||||
|
network := networkSettings.Networks[value]
|
||||||
|
if network != nil {
|
||||||
|
return network.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("Could not find network named '%s' for container '%s'! Maybe you're missing the project's prefix in the label? Defaulting to first available network.", value, container.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if container.NetworkSettings.NetworkMode.IsHost() {
|
||||||
|
if container.Node != nil {
|
||||||
|
if container.Node.IPAddress != "" {
|
||||||
|
return container.Node.IPAddress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "127.0.0.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
if container.NetworkSettings.NetworkMode.IsContainer() {
|
||||||
|
dockerClient, err := p.createClient()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to get IP address for container %s, error: %s", container.Name, err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
connectedContainer := container.NetworkSettings.NetworkMode.ConnectedContainer()
|
||||||
|
containerInspected, err := dockerClient.ContainerInspect(context.Background(), connectedContainer)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Unable to get IP address for container %s : Failed to inspect container ID %s, error: %s", container.Name, connectedContainer, err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return p.getIPAddress(parseContainer(containerInspected))
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.UseBindPortIP {
|
||||||
|
port := getPortV1(container)
|
||||||
|
for netPort, portBindings := range container.NetworkSettings.Ports {
|
||||||
|
if strings.EqualFold(string(netPort), port+"/TCP") || strings.EqualFold(string(netPort), port+"/UDP") {
|
||||||
|
for _, p := range portBindings {
|
||||||
|
return p.HostIP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, network := range container.NetworkSettings.Networks {
|
||||||
|
return network.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("Unable to find the IP address for the container %q.", container.Name)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
|
@ -898,7 +898,7 @@ func TestDockerGetIPAddressV1(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
dData := parseContainer(test.container)
|
dData := parseContainer(test.container)
|
||||||
provider := &Provider{}
|
provider := &Provider{}
|
||||||
actual := provider.getIPAddress(dData)
|
actual := provider.getDeprecatedIPAddress(dData)
|
||||||
if actual != test.expected {
|
if actual != test.expected {
|
||||||
t.Errorf("expected %q, got %q", test.expected, actual)
|
t.Errorf("expected %q, got %q", test.expected, actual)
|
||||||
}
|
}
|
||||||
|
|
|
@ -667,7 +667,7 @@ func TestSwarmGetIPAddressV1(t *testing.T) {
|
||||||
SwarmMode: true,
|
SwarmMode: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
actual := provider.getIPAddress(dData)
|
actual := provider.getDeprecatedIPAddress(dData)
|
||||||
if actual != test.expected {
|
if actual != test.expected {
|
||||||
t.Errorf("expected %q, got %q", test.expected, actual)
|
t.Errorf("expected %q, got %q", test.expected, actual)
|
||||||
}
|
}
|
||||||
|
|
100
server/server.go
100
server/server.go
|
@ -40,6 +40,59 @@ import (
|
||||||
|
|
||||||
var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0)
|
var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0)
|
||||||
|
|
||||||
|
func newHijackConnectionTracker() *hijackConnectionTracker {
|
||||||
|
return &hijackConnectionTracker{
|
||||||
|
conns: make(map[net.Conn]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type hijackConnectionTracker struct {
|
||||||
|
conns map[net.Conn]struct{}
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHijackedConnection add a connection in the tracked connections list
|
||||||
|
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
h.conns[conn] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveHijackedConnection remove a connection from the tracked connections list
|
||||||
|
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
delete(h.conns, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown wait for the connection closing
|
||||||
|
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
|
||||||
|
ticker := time.NewTicker(500 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
h.lock.RLock()
|
||||||
|
if len(h.conns) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.lock.RUnlock()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close close all the connections in the tracked connections list
|
||||||
|
func (h *hijackConnectionTracker) Close() {
|
||||||
|
for conn := range h.conns {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error while closing Hijacked conn: %v", err)
|
||||||
|
}
|
||||||
|
delete(h.conns, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Server is the reverse-proxy/load-balancer engine
|
// Server is the reverse-proxy/load-balancer engine
|
||||||
type Server struct {
|
type Server struct {
|
||||||
serverEntryPoints serverEntryPoints
|
serverEntryPoints serverEntryPoints
|
||||||
|
@ -80,6 +133,35 @@ type serverEntryPoint struct {
|
||||||
certs *traefiktls.CertificateStore
|
certs *traefiktls.CertificateStore
|
||||||
onDemandListener func(string) (*tls.Certificate, error)
|
onDemandListener func(string) (*tls.Certificate, error)
|
||||||
tlsALPNGetter func(string) (*tls.Certificate, error)
|
tlsALPNGetter func(string) (*tls.Certificate, error)
|
||||||
|
hijackConnectionTracker *hijackConnectionTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s serverEntryPoint) Shutdown(ctx context.Context) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
log.Debugf("Wait server shutdown is over due to: %s", err)
|
||||||
|
err = s.httpServer.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
log.Debugf("Wait hijack connection is over due to: %s", err)
|
||||||
|
s.hijackConnectionTracker.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer returns an initialized Server.
|
// NewServer returns an initialized Server.
|
||||||
|
@ -187,13 +269,7 @@ func (s *Server) Stop() {
|
||||||
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
|
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
|
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
|
||||||
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
|
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
|
||||||
if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil {
|
serverEntryPoint.Shutdown(ctx)
|
||||||
log.Debugf("Wait is over due to: %s", err)
|
|
||||||
err = serverEntryPoint.httpServer.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cancel()
|
cancel()
|
||||||
log.Debugf("Entrypoint %s closed", serverEntryPointName)
|
log.Debugf("Entrypoint %s closed", serverEntryPointName)
|
||||||
}(sepn, sep)
|
}(sepn, sep)
|
||||||
|
@ -447,6 +523,16 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer
|
||||||
serverEntryPoint.httpServer = newSrv
|
serverEntryPoint.httpServer = newSrv
|
||||||
serverEntryPoint.listener = listener
|
serverEntryPoint.listener = listener
|
||||||
|
|
||||||
|
serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker()
|
||||||
|
serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
|
||||||
|
switch state {
|
||||||
|
case http.StateHijacked:
|
||||||
|
serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn)
|
||||||
|
case http.StateClosed:
|
||||||
|
serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return serverEntryPoint
|
return serverEntryPoint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
|
@ -245,6 +246,15 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration
|
||||||
forward.Rewriter(rewriter),
|
forward.Rewriter(rewriter),
|
||||||
forward.ResponseModifier(responseModifier),
|
forward.ResponseModifier(responseModifier),
|
||||||
forward.BufferPool(s.bufferPool),
|
forward.BufferPool(s.bufferPool),
|
||||||
|
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
|
||||||
|
server := req.Context().Value(http.ServerContextKey).(*http.Server)
|
||||||
|
if server != nil {
|
||||||
|
connState := server.ConnState
|
||||||
|
if connState != nil {
|
||||||
|
connState(conn, http.StateClosed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)
|
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)
|
||||||
|
|
19
vendor/github.com/vulcand/oxy/forward/fwd.go
generated
vendored
19
vendor/github.com/vulcand/oxy/forward/fwd.go
generated
vendored
|
@ -7,6 +7,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
@ -126,6 +127,14 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed
|
||||||
|
func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter {
|
||||||
|
return func(f *Forwarder) error {
|
||||||
|
f.httpForwarder.websocketConnectionClosedHook = hook
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ResponseModifier defines a response modifier for the HTTP forwarder
|
// ResponseModifier defines a response modifier for the HTTP forwarder
|
||||||
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
|
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
|
||||||
return func(f *Forwarder) error {
|
return func(f *Forwarder) error {
|
||||||
|
@ -189,6 +198,7 @@ type httpForwarder struct {
|
||||||
log OxyLogger
|
log OxyLogger
|
||||||
|
|
||||||
bufferPool httputil.BufferPool
|
bufferPool httputil.BufferPool
|
||||||
|
websocketConnectionClosedHook func(req *http.Request, conn net.Conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultFlushInterval = time.Duration(100) * time.Millisecond
|
const defaultFlushInterval = time.Duration(100) * time.Millisecond
|
||||||
|
@ -374,8 +384,13 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
|
||||||
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
|
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer underlyingConn.Close()
|
defer func() {
|
||||||
defer targetConn.Close()
|
underlyingConn.Close()
|
||||||
|
targetConn.Close()
|
||||||
|
if f.websocketConnectionClosedHook != nil {
|
||||||
|
f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
errClient := make(chan error, 1)
|
errClient := make(chan error, 1)
|
||||||
errBackend := make(chan error, 1)
|
errBackend := make(chan error, 1)
|
||||||
|
|
Loading…
Reference in a new issue