Merge branch 'v1.7' into master

This commit is contained in:
Fernandez Ludovic 2018-07-19 17:33:28 +02:00
commit d8f69700e6
10 changed files with 407 additions and 38 deletions

2
Gopkg.lock generated
View file

@ -1266,7 +1266,7 @@
"roundrobin", "roundrobin",
"utils" "utils"
] ]
revision = "f0cbb9d6b797d92d168b95b5c443a31dfa67ccd0" revision = "a3ed5f65204f4ffccbb56d58cec466cdb7ab730b"
[[projects]] [[projects]]
name = "github.com/vulcand/predicate" name = "github.com/vulcand/predicate"

View file

@ -33,7 +33,7 @@ func (p *Provider) buildConfigurationV2(containersInspected []dockerData) *types
"getDomain": label.GetFuncString(label.TraefikDomain, p.Domain), "getDomain": label.GetFuncString(label.TraefikDomain, p.Domain),
// Backend functions // Backend functions
"getIPAddress": p.getIPAddress, "getIPAddress": p.getDeprecatedIPAddress, // TODO: Should we expose getIPPort instead?
"getServers": p.getServers, "getServers": p.getServers,
"getMaxConn": label.GetMaxConn, "getMaxConn": label.GetMaxConn,
"getHealthCheck": label.GetHealthCheck, "getHealthCheck": label.GetHealthCheck,
@ -235,17 +235,6 @@ func (p Provider) getIPAddress(container dockerData) string {
return p.getIPAddress(parseContainer(containerInspected)) return p.getIPAddress(parseContainer(containerInspected))
} }
if p.UseBindPortIP {
port := getPortV1(container)
for netPort, portBindings := range container.NetworkSettings.Ports {
if string(netPort) == port+"/TCP" || string(netPort) == port+"/UDP" {
for _, p := range portBindings {
return p.HostIP
}
}
}
}
for _, network := range container.NetworkSettings.Networks { for _, network := range container.NetworkSettings.Networks {
return network.Addr return network.Addr
} }
@ -254,6 +243,16 @@ func (p Provider) getIPAddress(container dockerData) string {
return "" return ""
} }
// Deprecated: Please use getIPPort instead
func (p *Provider) getDeprecatedIPAddress(container dockerData) string {
ip, _, err := p.getIPPort(container)
if err != nil {
log.Warn(err)
return ""
}
return ip
}
// Escape beginning slash "/", convert all others to dash "-", and convert underscores "_" to dash "-" // Escape beginning slash "/", convert all others to dash "-", and convert underscores "_" to dash "-"
func getSubDomain(name string) string { func getSubDomain(name string) string {
return strings.Replace(strings.Replace(strings.TrimPrefix(name, "/"), "/", "-", -1), "_", "-", -1) return strings.Replace(strings.Replace(strings.TrimPrefix(name, "/"), "/", "-", -1), "_", "-", -1)
@ -322,13 +321,53 @@ func getPort(container dockerData) string {
return "" return ""
} }
func (p *Provider) getPortBinding(container dockerData) (*nat.PortBinding, error) {
port := getPort(container)
for netPort, portBindings := range container.NetworkSettings.Ports {
if strings.EqualFold(string(netPort), port+"/TCP") || strings.EqualFold(string(netPort), port+"/UDP") {
for _, p := range portBindings {
return &p, nil
}
}
}
return nil, fmt.Errorf("unable to find the external IP:Port for the container %q", container.Name)
}
func (p *Provider) getIPPort(container dockerData) (string, string, error) {
var ip, port string
if p.UseBindPortIP {
portBinding, err := p.getPortBinding(container)
if err != nil {
return "", "", fmt.Errorf("unable to find a binding for the container %q: ignoring server", container.Name)
}
if portBinding.HostIP == "0.0.0.0" {
return "", "", fmt.Errorf("cannot determine the IP address (got 0.0.0.0) for the container %q: ignoring server", container.Name)
}
ip = portBinding.HostIP
port = portBinding.HostPort
} else {
ip = p.getIPAddress(container)
port = getPort(container)
}
if len(ip) == 0 {
return "", "", fmt.Errorf("unable to find the IP address for the container %q: the server is ignored", container.Name)
}
return ip, port, nil
}
func (p *Provider) getServers(containers []dockerData) map[string]types.Server { func (p *Provider) getServers(containers []dockerData) map[string]types.Server {
var servers map[string]types.Server var servers map[string]types.Server
for _, container := range containers { for _, container := range containers {
ip := p.getIPAddress(container) ip, port, err := p.getIPPort(container)
if len(ip) == 0 { if err != nil {
log.Warnf("Unable to find the IP address for the container %q: the server is ignored.", container.Name) log.Warn(err)
continue continue
} }
@ -337,7 +376,6 @@ func (p *Provider) getServers(containers []dockerData) map[string]types.Server {
} }
protocol := label.GetStringValue(container.SegmentLabels, label.TraefikProtocol, label.DefaultProtocol) protocol := label.GetStringValue(container.SegmentLabels, label.TraefikProtocol, label.DefaultProtocol)
port := getPort(container)
serverURL := fmt.Sprintf("%s://%s", protocol, net.JoinHostPort(ip, port)) serverURL := fmt.Sprintf("%s://%s", protocol, net.JoinHostPort(ip, port))

View file

@ -1287,12 +1287,173 @@ func TestDockerGetIPAddress(t *testing.T) {
Network: "webnet", Network: "webnet",
} }
actual := provider.getIPAddress(dData) actual := provider.getDeprecatedIPAddress(dData)
assert.Equal(t, test.expected, actual) assert.Equal(t, test.expected, actual)
}) })
} }
} }
func TestDockerGetIPPort(t *testing.T) {
testCases := []struct {
desc string
container docker.ContainerJSON
ip, port string
expectsError bool
}{
{
desc: "label traefik.port not set, binding with ip:port should create a route to the bound ip:port",
container: containerJSON(
ports(nat.PortMap{
"80/tcp": []nat.PortBinding{
{
HostIP: "1.2.3.4",
HostPort: "8081",
},
},
}),
withNetwork("testnet", ipv4("10.11.12.13"))),
ip: "1.2.3.4",
port: "8081",
},
{
desc: "label traefik.port set, multiple bindings on different ports, uses the label to select the correct (first) binding",
container: containerJSON(
labels(map[string]string{
label.TraefikPort: "80",
}),
ports(nat.PortMap{
"80/tcp": []nat.PortBinding{
{
HostIP: "1.2.3.4",
HostPort: "8081",
},
},
"443/tcp": []nat.PortBinding{
{
HostIP: "5.6.7.8",
HostPort: "8082",
},
},
}),
withNetwork("testnet", ipv4("10.11.12.13"))),
ip: "1.2.3.4",
port: "8081",
},
{
desc: "label traefik.port set, multiple bindings on different ports, uses the label to select the correct (second) binding",
container: containerJSON(
labels(map[string]string{
label.TraefikPort: "443",
}),
ports(nat.PortMap{
"80/tcp": []nat.PortBinding{
{
HostIP: "1.2.3.4",
HostPort: "8081",
},
},
"443/tcp": []nat.PortBinding{
{
HostIP: "5.6.7.8",
HostPort: "8082",
},
},
}),
withNetwork("testnet", ipv4("10.11.12.13"))),
ip: "5.6.7.8",
port: "8082",
},
{
desc: "label traefik.port set, single binding with ip:port for the label, creates the route",
container: containerJSON(
labels(map[string]string{
label.TraefikPort: "443",
}),
ports(nat.PortMap{
"443/tcp": []nat.PortBinding{
{
HostIP: "5.6.7.8",
HostPort: "8082",
},
},
}),
withNetwork("testnet", ipv4("10.11.12.13"))),
ip: "5.6.7.8",
port: "8082",
},
{
desc: "label traefik.port not set, single binding with port only, server ignored",
container: containerJSON(
ports(nat.PortMap{
"80/tcp": []nat.PortBinding{
{
HostPort: "8082",
},
},
}),
withNetwork("testnet", ipv4("10.11.12.13"))),
expectsError: true,
},
{
desc: "label traefik.port not set, no binding, server ignored",
container: containerJSON(
withNetwork("testnet", ipv4("10.11.12.13"))),
expectsError: true,
},
{
desc: "label traefik.port set, no binding on the corresponding port, server ignored",
container: containerJSON(
labels(map[string]string{
label.TraefikPort: "80",
}),
ports(nat.PortMap{
"443/tcp": []nat.PortBinding{
{
HostIP: "5.6.7.8",
HostPort: "8082",
},
},
}),
withNetwork("testnet", ipv4("10.11.12.13"))),
expectsError: true,
},
{
desc: "label traefik.port set, no binding, server ignored",
container: containerJSON(
labels(map[string]string{
label.TraefikPort: "80",
}),
withNetwork("testnet", ipv4("10.11.12.13"))),
expectsError: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
dData := parseContainer(test.container)
segmentProperties := label.ExtractTraefikLabels(dData.Labels)
dData.SegmentLabels = segmentProperties[""]
provider := &Provider{
Network: "webnet",
UseBindPortIP: true,
}
actualIP, actualPort, actualError := provider.getIPPort(dData)
if test.expectsError {
require.Error(t, actualError)
} else {
require.NoError(t, actualError)
}
assert.Equal(t, test.ip, actualIP)
assert.Equal(t, test.port, actualPort)
})
}
}
func TestDockerGetPort(t *testing.T) { func TestDockerGetPort(t *testing.T) {
testCases := []struct { testCases := []struct {
container docker.ContainerJSON container docker.ContainerJSON

View file

@ -933,7 +933,7 @@ func TestSwarmGetIPAddress(t *testing.T) {
segmentProperties := label.ExtractTraefikLabels(dData.Labels) segmentProperties := label.ExtractTraefikLabels(dData.Labels)
dData.SegmentLabels = segmentProperties[""] dData.SegmentLabels = segmentProperties[""]
actual := provider.getIPAddress(dData) actual := provider.getDeprecatedIPAddress(dData)
assert.Equal(t, test.expected, actual) assert.Equal(t, test.expected, actual)
}) })
} }

View file

@ -1,8 +1,10 @@
package docker package docker
import ( import (
"context"
"math" "math"
"strconv" "strconv"
"strings"
"text/template" "text/template"
"github.com/BurntSushi/ty/fun" "github.com/BurntSushi/ty/fun"
@ -19,7 +21,7 @@ func (p *Provider) buildConfigurationV1(containersInspected []dockerData) *types
"isBackendLBSwarm": isBackendLBSwarm, "isBackendLBSwarm": isBackendLBSwarm,
// Backend functions // Backend functions
"getIPAddress": p.getIPAddress, "getIPAddress": p.getIPAddressV1,
"getPort": getPortV1, "getPort": getPortV1,
"getWeight": getFuncIntLabelV1(label.TraefikWeight, label.DefaultWeight), "getWeight": getFuncIntLabelV1(label.TraefikWeight, label.DefaultWeight),
"getProtocol": getFuncStringLabelV1(label.TraefikProtocol, label.DefaultProtocol), "getProtocol": getFuncStringLabelV1(label.TraefikProtocol, label.DefaultProtocol),
@ -202,3 +204,60 @@ func (p Provider) containerFilterV1(container dockerData) bool {
return true return true
} }
func (p Provider) getIPAddressV1(container dockerData) string {
if value := label.GetStringValue(container.Labels, labelDockerNetwork, p.Network); value != "" {
networkSettings := container.NetworkSettings
if networkSettings.Networks != nil {
network := networkSettings.Networks[value]
if network != nil {
return network.Addr
}
log.Warnf("Could not find network named '%s' for container '%s'! Maybe you're missing the project's prefix in the label? Defaulting to first available network.", value, container.Name)
}
}
if container.NetworkSettings.NetworkMode.IsHost() {
if container.Node != nil {
if container.Node.IPAddress != "" {
return container.Node.IPAddress
}
}
return "127.0.0.1"
}
if container.NetworkSettings.NetworkMode.IsContainer() {
dockerClient, err := p.createClient()
if err != nil {
log.Warnf("Unable to get IP address for container %s, error: %s", container.Name, err)
return ""
}
connectedContainer := container.NetworkSettings.NetworkMode.ConnectedContainer()
containerInspected, err := dockerClient.ContainerInspect(context.Background(), connectedContainer)
if err != nil {
log.Warnf("Unable to get IP address for container %s : Failed to inspect container ID %s, error: %s", container.Name, connectedContainer, err)
return ""
}
return p.getIPAddress(parseContainer(containerInspected))
}
if p.UseBindPortIP {
port := getPortV1(container)
for netPort, portBindings := range container.NetworkSettings.Ports {
if strings.EqualFold(string(netPort), port+"/TCP") || strings.EqualFold(string(netPort), port+"/UDP") {
for _, p := range portBindings {
return p.HostIP
}
}
}
}
for _, network := range container.NetworkSettings.Networks {
return network.Addr
}
log.Warnf("Unable to find the IP address for the container %q.", container.Name)
return ""
}

View file

@ -898,7 +898,7 @@ func TestDockerGetIPAddressV1(t *testing.T) {
t.Parallel() t.Parallel()
dData := parseContainer(test.container) dData := parseContainer(test.container)
provider := &Provider{} provider := &Provider{}
actual := provider.getIPAddress(dData) actual := provider.getDeprecatedIPAddress(dData)
if actual != test.expected { if actual != test.expected {
t.Errorf("expected %q, got %q", test.expected, actual) t.Errorf("expected %q, got %q", test.expected, actual)
} }

View file

@ -667,7 +667,7 @@ func TestSwarmGetIPAddressV1(t *testing.T) {
SwarmMode: true, SwarmMode: true,
} }
actual := provider.getIPAddress(dData) actual := provider.getDeprecatedIPAddress(dData)
if actual != test.expected { if actual != test.expected {
t.Errorf("expected %q, got %q", test.expected, actual) t.Errorf("expected %q, got %q", test.expected, actual)
} }

View file

@ -40,6 +40,59 @@ import (
var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0) var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0)
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type hijackConnectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddHijackedConnection add a connection in the tracked connections list
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
h.conns[conn] = struct{}{}
}
// RemoveHijackedConnection remove a connection from the tracked connections list
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
delete(h.conns, conn)
}
// Shutdown wait for the connection closing
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
h.lock.RLock()
if len(h.conns) == 0 {
return nil
}
h.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.Errorf("Error while closing Hijacked conn: %v", err)
}
delete(h.conns, conn)
}
}
// Server is the reverse-proxy/load-balancer engine // Server is the reverse-proxy/load-balancer engine
type Server struct { type Server struct {
serverEntryPoints serverEntryPoints serverEntryPoints serverEntryPoints
@ -74,12 +127,41 @@ type EntryPoint struct {
type serverEntryPoints map[string]*serverEntryPoint type serverEntryPoints map[string]*serverEntryPoint
type serverEntryPoint struct { type serverEntryPoint struct {
httpServer *h2c.Server httpServer *h2c.Server
listener net.Listener listener net.Listener
httpRouter *middlewares.HandlerSwitcher httpRouter *middlewares.HandlerSwitcher
certs *traefiktls.CertificateStore certs *traefiktls.CertificateStore
onDemandListener func(string) (*tls.Certificate, error) onDemandListener func(string) (*tls.Certificate, error)
tlsALPNGetter func(string) (*tls.Certificate, error) tlsALPNGetter func(string) (*tls.Certificate, error)
hijackConnectionTracker *hijackConnectionTracker
}
func (s serverEntryPoint) Shutdown(ctx context.Context) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait server shutdown is over due to: %s", err)
err = s.httpServer.Close()
if err != nil {
log.Error(err)
}
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait hijack connection is over due to: %s", err)
s.hijackConnectionTracker.Close()
}
}
}()
wg.Wait()
} }
// NewServer returns an initialized Server. // NewServer returns an initialized Server.
@ -187,13 +269,7 @@ func (s *Server) Stop() {
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut) graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut) ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName) log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil { serverEntryPoint.Shutdown(ctx)
log.Debugf("Wait is over due to: %s", err)
err = serverEntryPoint.httpServer.Close()
if err != nil {
log.Error(err)
}
}
cancel() cancel()
log.Debugf("Entrypoint %s closed", serverEntryPointName) log.Debugf("Entrypoint %s closed", serverEntryPointName)
}(sepn, sep) }(sepn, sep)
@ -447,6 +523,16 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer
serverEntryPoint.httpServer = newSrv serverEntryPoint.httpServer = newSrv
serverEntryPoint.listener = listener serverEntryPoint.listener = listener
serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker()
serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateHijacked:
serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn)
case http.StateClosed:
serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn)
}
}
return serverEntryPoint return serverEntryPoint
} }

View file

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
"reflect" "reflect"
"sort" "sort"
@ -245,6 +246,15 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration
forward.Rewriter(rewriter), forward.Rewriter(rewriter),
forward.ResponseModifier(responseModifier), forward.ResponseModifier(responseModifier),
forward.BufferPool(s.bufferPool), forward.BufferPool(s.bufferPool),
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
server := req.Context().Value(http.ServerContextKey).(*http.Server)
if server != nil {
connState := server.ConnState
if connState != nil {
connState(conn, http.StateClosed)
}
}
}),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err) return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)

View file

@ -7,6 +7,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
@ -126,6 +127,14 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter {
} }
} }
// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed
func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.websocketConnectionClosedHook = hook
return nil
}
}
// ResponseModifier defines a response modifier for the HTTP forwarder // ResponseModifier defines a response modifier for the HTTP forwarder
func ResponseModifier(responseModifier func(*http.Response) error) optSetter { func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error { return func(f *Forwarder) error {
@ -188,7 +197,8 @@ type httpForwarder struct {
log OxyLogger log OxyLogger
bufferPool httputil.BufferPool bufferPool httputil.BufferPool
websocketConnectionClosedHook func(req *http.Request, conn net.Conn)
} }
const defaultFlushInterval = time.Duration(100) * time.Millisecond const defaultFlushInterval = time.Duration(100) * time.Millisecond
@ -374,8 +384,13 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
return return
} }
defer underlyingConn.Close() defer func() {
defer targetConn.Close() underlyingConn.Close()
targetConn.Close()
if f.websocketConnectionClosedHook != nil {
f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn())
}
}()
errClient := make(chan error, 1) errClient := make(chan error, 1)
errBackend := make(chan error, 1) errBackend := make(chan error, 1)