diff --git a/integration/healthcheck_test.go b/integration/healthcheck_test.go index 3dff3583d..128a447e8 100644 --- a/integration/healthcheck_test.go +++ b/integration/healthcheck_test.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "os" + "strings" "time" "github.com/go-check/check" @@ -319,30 +320,51 @@ func (s *HealthCheckSuite) TestPropagate(c *check.C) { try.Sleep(time.Second) - // Verify load-balancing on root still works, and that we're getting wsp2, wsp4, wsp2, wsp4, etc. - var want string - for i := 0; i < 4; i++ { - if i%2 == 0 { - want = `IP: ` + s.whoami4IP - } else { - want = `IP: ` + s.whoami2IP - } + want2 := `IP: ` + s.whoami2IP + want4 := `IP: ` + s.whoami4IP + // Verify load-balancing on root still works, and that we're getting an alternation between wsp2, and wsp4. + reachedServers := make(map[string]int) + for i := 0; i < 4; i++ { resp, err := client.Do(rootReq) c.Assert(err, checker.IsNil) body, err := io.ReadAll(resp.Body) c.Assert(err, checker.IsNil) - c.Assert(string(body), checker.Contains, want) + if reachedServers[s.whoami4IP] > reachedServers[s.whoami2IP] { + c.Assert(string(body), checker.Contains, want2) + reachedServers[s.whoami2IP]++ + continue + } + + if reachedServers[s.whoami2IP] > reachedServers[s.whoami4IP] { + c.Assert(string(body), checker.Contains, want4) + reachedServers[s.whoami4IP]++ + continue + } + + // First iteration, so we can't tell whether it's going to be wsp2, or wsp4. + if strings.Contains(string(body), `IP: `+s.whoami4IP) { + reachedServers[s.whoami4IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami2IP) { + reachedServers[s.whoami2IP]++ + continue + } } + c.Assert(reachedServers[s.whoami2IP], checker.Equals, 2) + c.Assert(reachedServers[s.whoami4IP], checker.Equals, 2) + fooReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000", nil) c.Assert(err, checker.IsNil) fooReq.Host = "foo.localhost" // Verify load-balancing on foo still works, and that we're getting wsp2, wsp2, wsp2, wsp2, etc. - want = `IP: ` + s.whoami2IP + want := `IP: ` + s.whoami2IP for i := 0; i < 4; i++ { resp, err := client.Do(fooReq) c.Assert(err, checker.IsNil) @@ -407,43 +429,109 @@ func (s *HealthCheckSuite) TestPropagate(c *check.C) { try.Sleep(time.Second) // Verify everything is up on root router. - wantIPs := []string{s.whoami3IP, s.whoami1IP, s.whoami4IP, s.whoami2IP} + reachedServers = make(map[string]int) for i := 0; i < 4; i++ { - want := `IP: ` + wantIPs[i] resp, err := client.Do(rootReq) c.Assert(err, checker.IsNil) body, err := io.ReadAll(resp.Body) c.Assert(err, checker.IsNil) - c.Assert(string(body), checker.Contains, want) + if strings.Contains(string(body), `IP: `+s.whoami1IP) { + reachedServers[s.whoami1IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami2IP) { + reachedServers[s.whoami2IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami3IP) { + reachedServers[s.whoami3IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami4IP) { + reachedServers[s.whoami4IP]++ + continue + } } + c.Assert(reachedServers[s.whoami1IP], checker.Equals, 1) + c.Assert(reachedServers[s.whoami2IP], checker.Equals, 1) + c.Assert(reachedServers[s.whoami3IP], checker.Equals, 1) + c.Assert(reachedServers[s.whoami4IP], checker.Equals, 1) + // Verify everything is up on foo router. - wantIPs = []string{s.whoami1IP, s.whoami1IP, s.whoami3IP, s.whoami2IP} + reachedServers = make(map[string]int) for i := 0; i < 4; i++ { - want := `IP: ` + wantIPs[i] resp, err := client.Do(fooReq) c.Assert(err, checker.IsNil) body, err := io.ReadAll(resp.Body) c.Assert(err, checker.IsNil) - c.Assert(string(body), checker.Contains, want) + if strings.Contains(string(body), `IP: `+s.whoami1IP) { + reachedServers[s.whoami1IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami2IP) { + reachedServers[s.whoami2IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami3IP) { + reachedServers[s.whoami3IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami4IP) { + reachedServers[s.whoami4IP]++ + continue + } } + c.Assert(reachedServers[s.whoami1IP], checker.Equals, 2) + c.Assert(reachedServers[s.whoami2IP], checker.Equals, 1) + c.Assert(reachedServers[s.whoami3IP], checker.Equals, 1) + c.Assert(reachedServers[s.whoami4IP], checker.Equals, 0) + // Verify everything is up on bar router. - wantIPs = []string{s.whoami1IP, s.whoami1IP, s.whoami3IP, s.whoami2IP} + reachedServers = make(map[string]int) for i := 0; i < 4; i++ { - want := `IP: ` + wantIPs[i] resp, err := client.Do(barReq) c.Assert(err, checker.IsNil) body, err := io.ReadAll(resp.Body) c.Assert(err, checker.IsNil) - c.Assert(string(body), checker.Contains, want) + if strings.Contains(string(body), `IP: `+s.whoami1IP) { + reachedServers[s.whoami1IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami2IP) { + reachedServers[s.whoami2IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami3IP) { + reachedServers[s.whoami3IP]++ + continue + } + + if strings.Contains(string(body), `IP: `+s.whoami4IP) { + reachedServers[s.whoami4IP]++ + continue + } } + + c.Assert(reachedServers[s.whoami1IP], checker.Equals, 2) + c.Assert(reachedServers[s.whoami2IP], checker.Equals, 1) + c.Assert(reachedServers[s.whoami3IP], checker.Equals, 1) + c.Assert(reachedServers[s.whoami4IP], checker.Equals, 0) } func (s *HealthCheckSuite) TestPropagateNoHealthCheck(c *check.C) { diff --git a/integration/retry_test.go b/integration/retry_test.go index 710595777..6a146e5ce 100644 --- a/integration/retry_test.go +++ b/integration/retry_test.go @@ -36,13 +36,11 @@ func (s *RetrySuite) TestRetry(c *check.C) { err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 60*time.Second, try.BodyContains("PathPrefix(`/`)")) c.Assert(err, checker.IsNil) - start := time.Now() - // This simulates a DialTimeout when connecting to the backend server. response, err := http.Get("http://127.0.0.1:8000/") - duration, allowed := time.Since(start), time.Millisecond*250 c.Assert(err, checker.IsNil) + + // The test only verifies that the retry middleware makes sure that the working service is eventually reached. c.Assert(response.StatusCode, checker.Equals, http.StatusOK) - c.Assert(int64(duration), checker.LessThan, int64(allowed)) } func (s *RetrySuite) TestRetryBackoff(c *check.C) { @@ -58,16 +56,11 @@ func (s *RetrySuite) TestRetryBackoff(c *check.C) { err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 60*time.Second, try.BodyContains("PathPrefix(`/`)")) c.Assert(err, checker.IsNil) - start := time.Now() - // This simulates a DialTimeout when connecting to the backend server. response, err := http.Get("http://127.0.0.1:8000/") - duration := time.Since(start) - // test case delays: 500 + 700 + 1000ms with randomization. It should be safely > 1500ms - minAllowed := time.Millisecond * 1400 - c.Assert(err, checker.IsNil) + + // The test only verifies that the retry middleware allows finally to reach the working service. c.Assert(response.StatusCode, checker.Equals, http.StatusOK) - c.Assert(int64(duration), checker.GreaterThan, int64(minAllowed)) } func (s *RetrySuite) TestRetryWebsocket(c *check.C) { @@ -83,11 +76,12 @@ func (s *RetrySuite) TestRetryWebsocket(c *check.C) { err = try.GetRequest("http://127.0.0.1:8080/api/rawdata", 60*time.Second, try.BodyContains("PathPrefix(`/`)")) c.Assert(err, checker.IsNil) - // This simulates a DialTimeout when connecting to the backend server. + // The test only verifies that the retry middleware makes sure that the working service is eventually reached. _, response, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:8000/echo", nil) c.Assert(err, checker.IsNil) c.Assert(response.StatusCode, checker.Equals, http.StatusSwitchingProtocols) + // The test verifies a second time that the working service is eventually reached. _, response, err = websocket.DefaultDialer.Dial("ws://127.0.0.1:8000/echo", nil) c.Assert(err, checker.IsNil) c.Assert(response.StatusCode, checker.Equals, http.StatusSwitchingProtocols) diff --git a/integration/udp_test.go b/integration/udp_test.go index 46eb25015..c5ee868a1 100644 --- a/integration/udp_test.go +++ b/integration/udp_test.go @@ -76,7 +76,7 @@ func (s *UDPSuite) TestWRR(c *check.C) { stop := make(chan struct{}) go func() { call := map[string]int{} - for i := 0; i < 4; i++ { + for i := 0; i < 8; i++ { out, err := guessWhoUDP("127.0.0.1:8093") c.Assert(err, checker.IsNil) switch { @@ -90,7 +90,7 @@ func (s *UDPSuite) TestWRR(c *check.C) { call["unknown"]++ } } - c.Assert(call, checker.DeepEquals, map[string]int{"whoami-a": 2, "whoami-b": 1, "whoami-c": 1}) + c.Assert(call, checker.DeepEquals, map[string]int{"whoami-a": 3, "whoami-b": 2, "whoami-c": 3}) close(stop) }() diff --git a/pkg/server/service/service.go b/pkg/server/service/service.go index b331f753f..67c2d97d1 100644 --- a/pkg/server/service/service.go +++ b/pkg/server/service/service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand" "net/http" "net/http/httputil" "net/url" @@ -51,6 +52,7 @@ func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics roundTripperManager: roundTripperManager, balancers: make(map[string]healthcheck.Balancers), configs: configs, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -66,6 +68,7 @@ type Manager struct { // which is why there is not just one Balancer per service name. balancers map[string]healthcheck.Balancers configs map[string]*runtime.ServiceInfo + rand *rand.Rand // For the initial shuffling of load-balancers. } // BuildHTTP Creates a http.Handler for a service configuration. @@ -212,7 +215,7 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string, } balancer := wrr.New(config.Sticky, config.HealthCheck) - for _, service := range config.Services { + for _, service := range shuffle(config.Services, m.rand) { serviceHandler, err := m.BuildHTTP(ctx, service.Name) if err != nil { return nil, err @@ -414,7 +417,7 @@ func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, servi func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []dynamic.Server) error { logger := log.FromContext(ctx) - for name, srv := range servers { + for name, srv := range shuffle(servers, m.rand) { u, err := url.Parse(srv.URL) if err != nil { return fmt.Errorf("error parsing server URL %s: %w", srv.URL, err) @@ -443,3 +446,11 @@ func convertSameSite(sameSite string) http.SameSite { return 0 } } + +func shuffle[T any](values []T, r *rand.Rand) []T { + shuffled := make([]T, len(values)) + copy(shuffled, values) + r.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) + + return shuffled +} diff --git a/pkg/server/service/service_test.go b/pkg/server/service/service_test.go index f479263ad..4d7ef7268 100644 --- a/pkg/server/service/service_test.go +++ b/pkg/server/service/service_test.go @@ -111,6 +111,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { type ExpectedResult struct { StatusCode int XFrom string + LoadBalanced bool SecureCookie bool HTTPOnlyCookie bool } @@ -139,12 +140,12 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { }, expected: []ExpectedResult{ { - StatusCode: http.StatusOK, - XFrom: "first", + StatusCode: http.StatusOK, + LoadBalanced: true, }, { - StatusCode: http.StatusOK, - XFrom: "second", + StatusCode: http.StatusOK, + LoadBalanced: true, }, }, }, @@ -193,11 +194,9 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { expected: []ExpectedResult{ { StatusCode: http.StatusOK, - XFrom: "first", }, { StatusCode: http.StatusOK, - XFrom: "first", }, }, }, @@ -302,13 +301,27 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) { req.Header.Set("Cookie", test.cookieRawValue) } + var prevXFrom string for _, expected := range test.expected { recorder := httptest.NewRecorder() handler.ServeHTTP(recorder, req) assert.Equal(t, expected.StatusCode, recorder.Code) - assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From")) + + if expected.XFrom != "" { + assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From")) + } + + xFrom := recorder.Header().Get("X-From") + if prevXFrom != "" { + if expected.LoadBalanced { + assert.NotEqual(t, prevXFrom, xFrom) + } else { + assert.Equal(t, prevXFrom, xFrom) + } + } + prevXFrom = xFrom cookieHeader := recorder.Header().Get("Set-Cookie") if len(cookieHeader) > 0 { diff --git a/pkg/server/service/tcp/service.go b/pkg/server/service/tcp/service.go index 69c9ccb31..e7dc4dbc6 100644 --- a/pkg/server/service/tcp/service.go +++ b/pkg/server/service/tcp/service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand" "net" "time" @@ -16,12 +17,14 @@ import ( // Manager is the TCPHandlers factory. type Manager struct { configs map[string]*runtime.TCPServiceInfo + rand *rand.Rand // For the initial shuffling of load-balancers. } // NewManager creates a new manager. func NewManager(conf *runtime.Configuration) *Manager { return &Manager{ configs: conf.TCPServices, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -53,7 +56,7 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han } duration := time.Duration(*conf.LoadBalancer.TerminationDelay) * time.Millisecond - for name, server := range conf.LoadBalancer.Servers { + for name, server := range shuffle(conf.LoadBalancer.Servers, m.rand) { if _, _, err := net.SplitHostPort(server.Address); err != nil { logger.Errorf("In service %q: %v", serviceQualifiedName, err) continue @@ -71,7 +74,8 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han return loadBalancer, nil case conf.Weighted != nil: loadBalancer := tcp.NewWRRLoadBalancer() - for _, service := range conf.Weighted.Services { + + for _, service := range shuffle(conf.Weighted.Services, m.rand) { handler, err := m.BuildTCP(rootCtx, service.Name) if err != nil { logger.Errorf("In service %q: %v", serviceQualifiedName, err) @@ -86,3 +90,11 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han return nil, err } } + +func shuffle[T any](values []T, r *rand.Rand) []T { + shuffled := make([]T, len(values)) + copy(shuffled, values) + r.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) + + return shuffled +} diff --git a/pkg/server/service/udp/service.go b/pkg/server/service/udp/service.go index 0fa3fd386..f388a7824 100644 --- a/pkg/server/service/udp/service.go +++ b/pkg/server/service/udp/service.go @@ -4,7 +4,9 @@ import ( "context" "errors" "fmt" + "math/rand" "net" + "time" "github.com/traefik/traefik/v2/pkg/config/runtime" "github.com/traefik/traefik/v2/pkg/log" @@ -15,12 +17,14 @@ import ( // Manager handles UDP services creation. type Manager struct { configs map[string]*runtime.UDPServiceInfo + rand *rand.Rand // For the initial shuffling of load-balancers. } // NewManager creates a new manager. func NewManager(conf *runtime.Configuration) *Manager { return &Manager{ configs: conf.UDPServices, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -46,7 +50,7 @@ func (m *Manager) BuildUDP(rootCtx context.Context, serviceName string) (udp.Han case conf.LoadBalancer != nil: loadBalancer := udp.NewWRRLoadBalancer() - for name, server := range conf.LoadBalancer.Servers { + for name, server := range shuffle(conf.LoadBalancer.Servers, m.rand) { if _, _, err := net.SplitHostPort(server.Address); err != nil { logger.Errorf("In udp service %q: %v", serviceQualifiedName, err) continue @@ -64,7 +68,8 @@ func (m *Manager) BuildUDP(rootCtx context.Context, serviceName string) (udp.Han return loadBalancer, nil case conf.Weighted != nil: loadBalancer := udp.NewWRRLoadBalancer() - for _, service := range conf.Weighted.Services { + + for _, service := range shuffle(conf.Weighted.Services, m.rand) { handler, err := m.BuildUDP(rootCtx, service.Name) if err != nil { logger.Errorf("In udp service %q: %v", serviceQualifiedName, err) @@ -79,3 +84,11 @@ func (m *Manager) BuildUDP(rootCtx context.Context, serviceName string) (udp.Han return nil, err } } + +func shuffle[T any](values []T, r *rand.Rand) []T { + shuffled := make([]T, len(values)) + copy(shuffled, values) + r.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) + + return shuffled +}