diff --git a/pkg/tcp/wrr_load_balancer.go b/pkg/tcp/wrr_load_balancer.go index 8f26251ef..236fd8007 100644 --- a/pkg/tcp/wrr_load_balancer.go +++ b/pkg/tcp/wrr_load_balancer.go @@ -15,7 +15,7 @@ type server struct { // WRRLoadBalancer is a naive RoundRobin load balancer for TCP services. type WRRLoadBalancer struct { servers []server - lock sync.RWMutex + lock sync.Mutex currentWeight int index int } @@ -29,16 +29,16 @@ func NewWRRLoadBalancer() *WRRLoadBalancer { // ServeTCP forwards the connection to the right service. func (b *WRRLoadBalancer) ServeTCP(conn WriteCloser) { - if len(b.servers) == 0 { - log.WithoutContext().Error("no available server") - return - } - + b.lock.Lock() next, err := b.next() + b.lock.Unlock() + if err != nil { log.WithoutContext().Errorf("Error during load balancing: %v", err) conn.Close() + return } + next.ServeTCP(conn) } @@ -50,6 +50,9 @@ func (b *WRRLoadBalancer) AddServer(serverHandler Handler) { // AddWeightServer appends a server to the existing list with a weight. func (b *WRRLoadBalancer) AddWeightServer(serverHandler Handler, weight *int) { + b.lock.Lock() + defer b.lock.Unlock() + w := 1 if weight != nil { w = *weight @@ -87,9 +90,6 @@ func gcd(a, b int) int { } func (b *WRRLoadBalancer) next() (Handler, error) { - b.lock.Lock() - defer b.lock.Unlock() - if len(b.servers) == 0 { return nil, fmt.Errorf("no servers in the pool") } @@ -98,10 +98,14 @@ func (b *WRRLoadBalancer) next() (Handler, error) { // it calculates the GCD and subtracts it on every iteration, what interleaves servers // and allows us not to build an iterator every time we readjust weights - // GCD across all enabled servers - gcd := b.weightGcd() // Maximum weight across all enabled servers max := b.maxWeight() + if max == 0 { + return nil, fmt.Errorf("all servers have 0 weight") + } + + // GCD across all enabled servers + gcd := b.weightGcd() for { b.index = (b.index + 1) % len(b.servers) @@ -109,9 +113,6 @@ func (b *WRRLoadBalancer) next() (Handler, error) { b.currentWeight -= gcd if b.currentWeight <= 0 { b.currentWeight = max - if b.currentWeight == 0 { - return nil, fmt.Errorf("all servers have 0 weight") - } } } srv := b.servers[b.index] diff --git a/pkg/tcp/wrr_load_balancer_test.go b/pkg/tcp/wrr_load_balancer_test.go index 933f48976..a07f51762 100644 --- a/pkg/tcp/wrr_load_balancer_test.go +++ b/pkg/tcp/wrr_load_balancer_test.go @@ -10,7 +10,8 @@ import ( ) type fakeConn struct { - call map[string]int + writeCall map[string]int + closeCall int } func (f *fakeConn) Read(b []byte) (n int, err error) { @@ -18,12 +19,13 @@ func (f *fakeConn) Read(b []byte) (n int, err error) { } func (f *fakeConn) Write(b []byte) (n int, err error) { - f.call[string(b)]++ + f.writeCall[string(b)]++ return len(b), nil } func (f *fakeConn) Close() error { - panic("implement me") + f.closeCall++ + return nil } func (f *fakeConn) LocalAddr() net.Addr { @@ -55,7 +57,8 @@ func TestLoadBalancing(t *testing.T) { desc string serversWeight map[string]int totalCall int - expected map[string]int + expectedWrite map[string]int + expectedClose int }{ { desc: "RoundRobin", @@ -64,7 +67,7 @@ func TestLoadBalancing(t *testing.T) { "h2": 1, }, totalCall: 4, - expected: map[string]int{ + expectedWrite: map[string]int{ "h1": 2, "h2": 2, }, @@ -76,7 +79,7 @@ func TestLoadBalancing(t *testing.T) { "h2": 1, }, totalCall: 4, - expected: map[string]int{ + expectedWrite: map[string]int{ "h1": 3, "h2": 1, }, @@ -88,22 +91,33 @@ func TestLoadBalancing(t *testing.T) { "h2": 1, }, totalCall: 16, - expected: map[string]int{ + expectedWrite: map[string]int{ "h1": 12, "h2": 4, }, }, { - desc: "WeighedRoundRobin with 0 weight server", + desc: "WeighedRoundRobin with one 0 weight server", serversWeight: map[string]int{ "h1": 3, "h2": 0, }, totalCall: 16, - expected: map[string]int{ + expectedWrite: map[string]int{ "h1": 16, }, }, + { + desc: "WeighedRoundRobin with all servers with 0 weight", + serversWeight: map[string]int{ + "h1": 0, + "h2": 0, + "h3": 0, + }, + totalCall: 10, + expectedWrite: map[string]int{}, + expectedClose: 10, + }, } for _, test := range testCases { @@ -120,12 +134,13 @@ func TestLoadBalancing(t *testing.T) { }), &weight) } - conn := &fakeConn{call: make(map[string]int)} + conn := &fakeConn{writeCall: make(map[string]int)} for i := 0; i < test.totalCall; i++ { balancer.ServeTCP(conn) } - assert.Equal(t, test.expected, conn.call) + assert.Equal(t, test.expectedWrite, conn.writeCall) + assert.Equal(t, test.expectedClose, conn.closeCall) }) } } diff --git a/pkg/udp/wrr_load_balancer.go b/pkg/udp/wrr_load_balancer.go index 384b699fc..7dd2a9c07 100644 --- a/pkg/udp/wrr_load_balancer.go +++ b/pkg/udp/wrr_load_balancer.go @@ -15,7 +15,7 @@ type server struct { // WRRLoadBalancer is a naive RoundRobin load balancer for UDP services. type WRRLoadBalancer struct { servers []server - lock sync.RWMutex + lock sync.Mutex currentWeight int index int } @@ -29,16 +29,16 @@ func NewWRRLoadBalancer() *WRRLoadBalancer { // ServeUDP forwards the connection to the right service. func (b *WRRLoadBalancer) ServeUDP(conn *Conn) { - if len(b.servers) == 0 { - log.WithoutContext().Error("no available server") - return - } - + b.lock.Lock() next, err := b.next() + b.lock.Unlock() + if err != nil { log.WithoutContext().Errorf("Error during load balancing: %v", err) conn.Close() + return } + next.ServeUDP(conn) } @@ -50,6 +50,9 @@ func (b *WRRLoadBalancer) AddServer(serverHandler Handler) { // AddWeightedServer appends a handler to the existing list with a weight. func (b *WRRLoadBalancer) AddWeightedServer(serverHandler Handler, weight *int) { + b.lock.Lock() + defer b.lock.Unlock() + w := 1 if weight != nil { w = *weight @@ -87,9 +90,6 @@ func gcd(a, b int) int { } func (b *WRRLoadBalancer) next() (Handler, error) { - b.lock.Lock() - defer b.lock.Unlock() - if len(b.servers) == 0 { return nil, fmt.Errorf("no servers in the pool") } @@ -98,10 +98,14 @@ func (b *WRRLoadBalancer) next() (Handler, error) { // but is actually very simple it calculates the GCD and subtracts it on every iteration, // what interleaves servers and allows us not to build an iterator every time we readjust weights. - // GCD across all enabled servers - gcd := b.weightGcd() // Maximum weight across all enabled servers max := b.maxWeight() + if max == 0 { + return nil, fmt.Errorf("all servers have 0 weight") + } + + // GCD across all enabled servers + gcd := b.weightGcd() for { b.index = (b.index + 1) % len(b.servers) @@ -109,9 +113,6 @@ func (b *WRRLoadBalancer) next() (Handler, error) { b.currentWeight -= gcd if b.currentWeight <= 0 { b.currentWeight = max - if b.currentWeight == 0 { - return nil, fmt.Errorf("all servers have 0 weight") - } } } srv := b.servers[b.index]