fix: TCP/UDP wrr when all servers have a weight set to 0

Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
Tom Moulard 2021-11-08 17:58:12 +01:00 committed by GitHub
parent ffdfc13461
commit d91eefa74f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 39 deletions

View file

@ -15,7 +15,7 @@ type server struct {
// WRRLoadBalancer is a naive RoundRobin load balancer for TCP services. // WRRLoadBalancer is a naive RoundRobin load balancer for TCP services.
type WRRLoadBalancer struct { type WRRLoadBalancer struct {
servers []server servers []server
lock sync.RWMutex lock sync.Mutex
currentWeight int currentWeight int
index int index int
} }
@ -29,16 +29,16 @@ func NewWRRLoadBalancer() *WRRLoadBalancer {
// ServeTCP forwards the connection to the right service. // ServeTCP forwards the connection to the right service.
func (b *WRRLoadBalancer) ServeTCP(conn WriteCloser) { func (b *WRRLoadBalancer) ServeTCP(conn WriteCloser) {
if len(b.servers) == 0 { b.lock.Lock()
log.WithoutContext().Error("no available server")
return
}
next, err := b.next() next, err := b.next()
b.lock.Unlock()
if err != nil { if err != nil {
log.WithoutContext().Errorf("Error during load balancing: %v", err) log.WithoutContext().Errorf("Error during load balancing: %v", err)
conn.Close() conn.Close()
return
} }
next.ServeTCP(conn) next.ServeTCP(conn)
} }
@ -50,6 +50,9 @@ func (b *WRRLoadBalancer) AddServer(serverHandler Handler) {
// AddWeightServer appends a server to the existing list with a weight. // AddWeightServer appends a server to the existing list with a weight.
func (b *WRRLoadBalancer) AddWeightServer(serverHandler Handler, weight *int) { func (b *WRRLoadBalancer) AddWeightServer(serverHandler Handler, weight *int) {
b.lock.Lock()
defer b.lock.Unlock()
w := 1 w := 1
if weight != nil { if weight != nil {
w = *weight w = *weight
@ -87,9 +90,6 @@ func gcd(a, b int) int {
} }
func (b *WRRLoadBalancer) next() (Handler, error) { func (b *WRRLoadBalancer) next() (Handler, error) {
b.lock.Lock()
defer b.lock.Unlock()
if len(b.servers) == 0 { if len(b.servers) == 0 {
return nil, fmt.Errorf("no servers in the pool") return nil, fmt.Errorf("no servers in the pool")
} }
@ -98,10 +98,14 @@ func (b *WRRLoadBalancer) next() (Handler, error) {
// it calculates the GCD and subtracts it on every iteration, what interleaves servers // it calculates the GCD and subtracts it on every iteration, what interleaves servers
// and allows us not to build an iterator every time we readjust weights // and allows us not to build an iterator every time we readjust weights
// GCD across all enabled servers
gcd := b.weightGcd()
// Maximum weight across all enabled servers // Maximum weight across all enabled servers
max := b.maxWeight() max := b.maxWeight()
if max == 0 {
return nil, fmt.Errorf("all servers have 0 weight")
}
// GCD across all enabled servers
gcd := b.weightGcd()
for { for {
b.index = (b.index + 1) % len(b.servers) b.index = (b.index + 1) % len(b.servers)
@ -109,9 +113,6 @@ func (b *WRRLoadBalancer) next() (Handler, error) {
b.currentWeight -= gcd b.currentWeight -= gcd
if b.currentWeight <= 0 { if b.currentWeight <= 0 {
b.currentWeight = max b.currentWeight = max
if b.currentWeight == 0 {
return nil, fmt.Errorf("all servers have 0 weight")
}
} }
} }
srv := b.servers[b.index] srv := b.servers[b.index]

View file

@ -10,7 +10,8 @@ import (
) )
type fakeConn struct { type fakeConn struct {
call map[string]int writeCall map[string]int
closeCall int
} }
func (f *fakeConn) Read(b []byte) (n int, err error) { func (f *fakeConn) Read(b []byte) (n int, err error) {
@ -18,12 +19,13 @@ func (f *fakeConn) Read(b []byte) (n int, err error) {
} }
func (f *fakeConn) Write(b []byte) (n int, err error) { func (f *fakeConn) Write(b []byte) (n int, err error) {
f.call[string(b)]++ f.writeCall[string(b)]++
return len(b), nil return len(b), nil
} }
func (f *fakeConn) Close() error { func (f *fakeConn) Close() error {
panic("implement me") f.closeCall++
return nil
} }
func (f *fakeConn) LocalAddr() net.Addr { func (f *fakeConn) LocalAddr() net.Addr {
@ -55,7 +57,8 @@ func TestLoadBalancing(t *testing.T) {
desc string desc string
serversWeight map[string]int serversWeight map[string]int
totalCall int totalCall int
expected map[string]int expectedWrite map[string]int
expectedClose int
}{ }{
{ {
desc: "RoundRobin", desc: "RoundRobin",
@ -64,7 +67,7 @@ func TestLoadBalancing(t *testing.T) {
"h2": 1, "h2": 1,
}, },
totalCall: 4, totalCall: 4,
expected: map[string]int{ expectedWrite: map[string]int{
"h1": 2, "h1": 2,
"h2": 2, "h2": 2,
}, },
@ -76,7 +79,7 @@ func TestLoadBalancing(t *testing.T) {
"h2": 1, "h2": 1,
}, },
totalCall: 4, totalCall: 4,
expected: map[string]int{ expectedWrite: map[string]int{
"h1": 3, "h1": 3,
"h2": 1, "h2": 1,
}, },
@ -88,22 +91,33 @@ func TestLoadBalancing(t *testing.T) {
"h2": 1, "h2": 1,
}, },
totalCall: 16, totalCall: 16,
expected: map[string]int{ expectedWrite: map[string]int{
"h1": 12, "h1": 12,
"h2": 4, "h2": 4,
}, },
}, },
{ {
desc: "WeighedRoundRobin with 0 weight server", desc: "WeighedRoundRobin with one 0 weight server",
serversWeight: map[string]int{ serversWeight: map[string]int{
"h1": 3, "h1": 3,
"h2": 0, "h2": 0,
}, },
totalCall: 16, totalCall: 16,
expected: map[string]int{ expectedWrite: map[string]int{
"h1": 16, "h1": 16,
}, },
}, },
{
desc: "WeighedRoundRobin with all servers with 0 weight",
serversWeight: map[string]int{
"h1": 0,
"h2": 0,
"h3": 0,
},
totalCall: 10,
expectedWrite: map[string]int{},
expectedClose: 10,
},
} }
for _, test := range testCases { for _, test := range testCases {
@ -120,12 +134,13 @@ func TestLoadBalancing(t *testing.T) {
}), &weight) }), &weight)
} }
conn := &fakeConn{call: make(map[string]int)} conn := &fakeConn{writeCall: make(map[string]int)}
for i := 0; i < test.totalCall; i++ { for i := 0; i < test.totalCall; i++ {
balancer.ServeTCP(conn) balancer.ServeTCP(conn)
} }
assert.Equal(t, test.expected, conn.call) assert.Equal(t, test.expectedWrite, conn.writeCall)
assert.Equal(t, test.expectedClose, conn.closeCall)
}) })
} }
} }

View file

@ -15,7 +15,7 @@ type server struct {
// WRRLoadBalancer is a naive RoundRobin load balancer for UDP services. // WRRLoadBalancer is a naive RoundRobin load balancer for UDP services.
type WRRLoadBalancer struct { type WRRLoadBalancer struct {
servers []server servers []server
lock sync.RWMutex lock sync.Mutex
currentWeight int currentWeight int
index int index int
} }
@ -29,16 +29,16 @@ func NewWRRLoadBalancer() *WRRLoadBalancer {
// ServeUDP forwards the connection to the right service. // ServeUDP forwards the connection to the right service.
func (b *WRRLoadBalancer) ServeUDP(conn *Conn) { func (b *WRRLoadBalancer) ServeUDP(conn *Conn) {
if len(b.servers) == 0 { b.lock.Lock()
log.WithoutContext().Error("no available server")
return
}
next, err := b.next() next, err := b.next()
b.lock.Unlock()
if err != nil { if err != nil {
log.WithoutContext().Errorf("Error during load balancing: %v", err) log.WithoutContext().Errorf("Error during load balancing: %v", err)
conn.Close() conn.Close()
return
} }
next.ServeUDP(conn) next.ServeUDP(conn)
} }
@ -50,6 +50,9 @@ func (b *WRRLoadBalancer) AddServer(serverHandler Handler) {
// AddWeightedServer appends a handler to the existing list with a weight. // AddWeightedServer appends a handler to the existing list with a weight.
func (b *WRRLoadBalancer) AddWeightedServer(serverHandler Handler, weight *int) { func (b *WRRLoadBalancer) AddWeightedServer(serverHandler Handler, weight *int) {
b.lock.Lock()
defer b.lock.Unlock()
w := 1 w := 1
if weight != nil { if weight != nil {
w = *weight w = *weight
@ -87,9 +90,6 @@ func gcd(a, b int) int {
} }
func (b *WRRLoadBalancer) next() (Handler, error) { func (b *WRRLoadBalancer) next() (Handler, error) {
b.lock.Lock()
defer b.lock.Unlock()
if len(b.servers) == 0 { if len(b.servers) == 0 {
return nil, fmt.Errorf("no servers in the pool") return nil, fmt.Errorf("no servers in the pool")
} }
@ -98,10 +98,14 @@ func (b *WRRLoadBalancer) next() (Handler, error) {
// but is actually very simple it calculates the GCD and subtracts it on every iteration, // but is actually very simple it calculates the GCD and subtracts it on every iteration,
// what interleaves servers and allows us not to build an iterator every time we readjust weights. // what interleaves servers and allows us not to build an iterator every time we readjust weights.
// GCD across all enabled servers
gcd := b.weightGcd()
// Maximum weight across all enabled servers // Maximum weight across all enabled servers
max := b.maxWeight() max := b.maxWeight()
if max == 0 {
return nil, fmt.Errorf("all servers have 0 weight")
}
// GCD across all enabled servers
gcd := b.weightGcd()
for { for {
b.index = (b.index + 1) % len(b.servers) b.index = (b.index + 1) % len(b.servers)
@ -109,9 +113,6 @@ func (b *WRRLoadBalancer) next() (Handler, error) {
b.currentWeight -= gcd b.currentWeight -= gcd
if b.currentWeight <= 0 { if b.currentWeight <= 0 {
b.currentWeight = max b.currentWeight = max
if b.currentWeight == 0 {
return nil, fmt.Errorf("all servers have 0 weight")
}
} }
} }
srv := b.servers[b.index] srv := b.servers[b.index]