2018-07-11 10:08:03 +02:00
// Package roundrobin implements a dynamic weighted round-robin load balancer as an HTTP handler.
2017-02-07 22:33:23 +01:00
package roundrobin
import (
"fmt"
"net/http"
"net/url"
"sync"
2018-01-22 12:16:03 +01:00
log "github.com/sirupsen/logrus"
2017-02-07 22:33:23 +01:00
"github.com/vulcand/oxy/utils"
)
// Weight returns a ServerOption that sets a server's relative weight.
// Negative weights are rejected with an error and leave the server unchanged.
// A server with weight 0 is never picked by the round-robin iterator, although
// UpsertServer replaces 0 with the package default when adding a new server.
func Weight(w int) ServerOption {
	return func(s *server) error {
		if w >= 0 {
			s.weight = w
			return nil
		}
		return fmt.Errorf("Weight should be >= 0")
	}
}
// ErrorHandler returns an LBOption that installs h as the handler ServeHTTP
// invokes when no backend server can be selected. If h is nil (or the option
// is not used), New falls back to utils.DefaultHandler.
func ErrorHandler(h utils.ErrorHandler) LBOption {
	setter := func(rr *RoundRobin) error {
		rr.errHandler = h
		return nil
	}
	return setter
}
2018-07-11 10:08:03 +02:00
// EnableStickySession enable sticky session
2017-11-22 18:20:03 +01:00
func EnableStickySession ( stickySession * StickySession ) LBOption {
2017-02-07 22:33:23 +01:00
return func ( s * RoundRobin ) error {
2017-11-22 18:20:03 +01:00
s . stickySession = stickySession
return nil
}
}
2018-07-11 10:08:03 +02:00
// RoundRobinRequestRewriteListener is a functional argument that sets error handler of the server
2017-11-22 18:20:03 +01:00
func RoundRobinRequestRewriteListener ( rrl RequestRewriteListener ) LBOption {
return func ( s * RoundRobin ) error {
s . requestRewriteListener = rrl
2017-02-07 22:33:23 +01:00
return nil
}
}
2018-07-11 10:08:03 +02:00
// RoundRobin implements dynamic weighted round robin load balancer http handler
2017-02-07 22:33:23 +01:00
type RoundRobin struct {
mutex * sync . Mutex
next http . Handler
errHandler utils . ErrorHandler
// Current index (starts from -1)
2017-11-22 18:20:03 +01:00
index int
servers [ ] * server
currentWeight int
stickySession * StickySession
requestRewriteListener RequestRewriteListener
2018-07-11 10:08:03 +02:00
log * log . Logger
2017-02-07 22:33:23 +01:00
}
2018-07-11 10:08:03 +02:00
// New created a new RoundRobin
2017-02-07 22:33:23 +01:00
func New ( next http . Handler , opts ... LBOption ) ( * RoundRobin , error ) {
rr := & RoundRobin {
2017-11-22 18:20:03 +01:00
next : next ,
index : - 1 ,
mutex : & sync . Mutex { } ,
servers : [ ] * server { } ,
stickySession : nil ,
2018-07-11 10:08:03 +02:00
log : log . StandardLogger ( ) ,
2017-02-07 22:33:23 +01:00
}
for _ , o := range opts {
if err := o ( rr ) ; err != nil {
return nil , err
}
}
if rr . errHandler == nil {
rr . errHandler = utils . DefaultHandler
}
return rr , nil
}
2018-07-11 10:08:03 +02:00
// RoundRobinLogger defines the logger the round robin load balancer will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func RoundRobinLogger ( l * log . Logger ) LBOption {
return func ( r * RoundRobin ) error {
r . log = l
return nil
}
}
// Next returns the next handler
2017-02-07 22:33:23 +01:00
func ( r * RoundRobin ) Next ( ) http . Handler {
return r . next
}
func ( r * RoundRobin ) ServeHTTP ( w http . ResponseWriter , req * http . Request ) {
2018-07-11 10:08:03 +02:00
if r . log . Level >= log . DebugLevel {
logEntry := r . log . WithField ( "Request" , utils . DumpHttpRequest ( req ) )
2017-11-22 18:20:03 +01:00
logEntry . Debug ( "vulcand/oxy/roundrobin/rr: begin ServeHttp on request" )
2018-06-04 14:14:03 +02:00
defer logEntry . Debug ( "vulcand/oxy/roundrobin/rr: completed ServeHttp on request" )
2017-11-22 18:20:03 +01:00
}
2017-02-07 22:33:23 +01:00
// make shallow copy of request before chaning anything to avoid side effects
newReq := * req
stuck := false
2017-11-22 18:20:03 +01:00
if r . stickySession != nil {
cookieURL , present , err := r . stickySession . GetBackend ( & newReq , r . Servers ( ) )
2017-02-07 22:33:23 +01:00
if err != nil {
2018-02-12 17:24:03 +01:00
log . Warnf ( "vulcand/oxy/roundrobin/rr: error using server from cookie: %v" , err )
2017-02-07 22:33:23 +01:00
}
if present {
2017-11-22 18:20:03 +01:00
newReq . URL = cookieURL
2017-02-07 22:33:23 +01:00
stuck = true
}
}
if ! stuck {
url , err := r . NextServer ( )
if err != nil {
r . errHandler . ServeHTTP ( w , req , err )
return
}
2017-11-22 18:20:03 +01:00
if r . stickySession != nil {
r . stickySession . StickBackend ( url , & w )
2017-02-07 22:33:23 +01:00
}
newReq . URL = url
}
2017-11-22 18:20:03 +01:00
2018-07-11 10:08:03 +02:00
if r . log . Level >= log . DebugLevel {
// log which backend URL we're sending this request to
r . log . WithFields ( log . Fields { "Request" : utils . DumpHttpRequest ( req ) , "ForwardURL" : newReq . URL } ) . Debugf ( "vulcand/oxy/roundrobin/rr: Forwarding this request to URL" )
2017-11-22 18:20:03 +01:00
}
2018-07-11 10:08:03 +02:00
// Emit event to a listener if one exists
2017-11-22 18:20:03 +01:00
if r . requestRewriteListener != nil {
r . requestRewriteListener ( req , & newReq )
}
2017-02-07 22:33:23 +01:00
r . next . ServeHTTP ( w , & newReq )
}
2018-07-11 10:08:03 +02:00
// NextServer gets the next server
2017-02-07 22:33:23 +01:00
func ( r * RoundRobin ) NextServer ( ) ( * url . URL , error ) {
srv , err := r . nextServer ( )
if err != nil {
return nil , err
}
return utils . CopyURL ( srv . url ) , nil
}
// nextServer picks the next server using interleaved weighted round robin.
//
// The algorithm builds no schedule. It walks the pool cyclically while keeping
// a weight threshold in r.currentWeight:
//   - Each time the walk wraps back to the first server, the threshold drops by
//     the GCD of all weights.
//   - Once the threshold falls to zero or below, it restarts at the maximum
//     weight.
//   - A server is chosen when its weight is at least the current threshold.
//
// Heavier servers therefore qualify on more passes, and picks from different
// servers interleave. Weight changes need no rebuilt iterator.
//
// It returns an error if the pool is empty or every server has weight 0.
func (r *RoundRobin) nextServer() (*server, error) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	n := len(r.servers)
	if n == 0 {
		return nil, fmt.Errorf("no servers in the pool")
	}

	step := r.weightGcd()    // threshold decrement per full pass
	ceiling := r.maxWeight() // threshold at the start of each cycle

	for {
		r.index = (r.index + 1) % n
		if r.index == 0 {
			// Finished a pass over the pool: lower the bar.
			r.currentWeight -= step
			if r.currentWeight <= 0 {
				r.currentWeight = ceiling
				if r.currentWeight == 0 {
					// The maximum weight is 0, so no server can ever qualify.
					return nil, fmt.Errorf("all servers have 0 weight")
				}
			}
		}
		if candidate := r.servers[r.index]; candidate.weight >= r.currentWeight {
			return candidate, nil
		}
	}
}
2018-07-11 10:08:03 +02:00
// RemoveServer remove a server
2017-02-07 22:33:23 +01:00
func ( r * RoundRobin ) RemoveServer ( u * url . URL ) error {
r . mutex . Lock ( )
defer r . mutex . Unlock ( )
e , index := r . findServerByURL ( u )
if e == nil {
return fmt . Errorf ( "server not found" )
}
r . servers = append ( r . servers [ : index ] , r . servers [ index + 1 : ] ... )
r . resetState ( )
return nil
}
2018-07-11 10:08:03 +02:00
// Servers gets servers URL
2018-04-10 17:24:04 +02:00
func ( r * RoundRobin ) Servers ( ) [ ] * url . URL {
r . mutex . Lock ( )
defer r . mutex . Unlock ( )
2017-02-07 22:33:23 +01:00
2018-04-10 17:24:04 +02:00
out := make ( [ ] * url . URL , len ( r . servers ) )
for i , srv := range r . servers {
2017-02-07 22:33:23 +01:00
out [ i ] = srv . url
}
return out
}
2018-07-11 10:08:03 +02:00
// ServerWeight gets the server weight
2018-04-10 17:24:04 +02:00
func ( r * RoundRobin ) ServerWeight ( u * url . URL ) ( int , bool ) {
r . mutex . Lock ( )
defer r . mutex . Unlock ( )
2017-02-07 22:33:23 +01:00
2018-04-10 17:24:04 +02:00
if s , _ := r . findServerByURL ( u ) ; s != nil {
2017-02-07 22:33:23 +01:00
return s . weight , true
}
return - 1 , false
}
2018-07-11 10:08:03 +02:00
// UpsertServer In case if server is already present in the load balancer, returns error
2018-04-10 17:24:04 +02:00
func ( r * RoundRobin ) UpsertServer ( u * url . URL , options ... ServerOption ) error {
r . mutex . Lock ( )
defer r . mutex . Unlock ( )
2017-02-07 22:33:23 +01:00
if u == nil {
return fmt . Errorf ( "server URL can't be nil" )
}
2018-04-10 17:24:04 +02:00
if s , _ := r . findServerByURL ( u ) ; s != nil {
2017-02-07 22:33:23 +01:00
for _ , o := range options {
if err := o ( s ) ; err != nil {
return err
}
}
2018-04-10 17:24:04 +02:00
r . resetState ( )
2017-02-07 22:33:23 +01:00
return nil
}
srv := & server { url : utils . CopyURL ( u ) }
for _ , o := range options {
if err := o ( srv ) ; err != nil {
return err
}
}
if srv . weight == 0 {
srv . weight = defaultWeight
}
2018-04-10 17:24:04 +02:00
r . servers = append ( r . servers , srv )
r . resetState ( )
2017-02-07 22:33:23 +01:00
return nil
}
// resetIterator rewinds the weighted round-robin cursor. The next nextServer
// call then starts a new pass at the first server with the threshold reset.
func (r *RoundRobin) resetIterator() {
	r.index, r.currentWeight = -1, 0
}
// resetState discards all selection state derived from the pool. Call it after
// the pool changes. Currently that state is only the round-robin iterator.
func (r *RoundRobin) resetState() {
	r.resetIterator()
}
// findServerByURL returns the pool entry matching u according to sameURL,
// together with its index. It returns nil and -1 when there is no match,
// including when the pool is empty. Every caller in this file holds r.mutex
// while calling it.
func (r *RoundRobin) findServerByURL(u *url.URL) (*server, int) {
	for i := range r.servers {
		if sameURL(u, r.servers[i].url) {
			return r.servers[i], i
		}
	}
	return nil, -1
}
2018-04-10 17:24:04 +02:00
func ( r * RoundRobin ) maxWeight ( ) int {
2017-02-07 22:33:23 +01:00
max := - 1
2018-04-10 17:24:04 +02:00
for _ , s := range r . servers {
2017-02-07 22:33:23 +01:00
if s . weight > max {
max = s . weight
}
}
return max
}
2018-04-10 17:24:04 +02:00
func ( r * RoundRobin ) weightGcd ( ) int {
2017-02-07 22:33:23 +01:00
divisor := - 1
2018-04-10 17:24:04 +02:00
for _ , s := range r . servers {
2017-02-07 22:33:23 +01:00
if divisor == - 1 {
divisor = s . weight
} else {
divisor = gcd ( divisor , s . weight )
}
}
return divisor
}
// gcd returns the greatest common divisor of a and b using Euclid's algorithm.
// gcd(x, 0) is x, so gcd(0, 0) is 0.
func gcd(a, b int) int {
	if b == 0 {
		return a
	}
	return gcd(b, a%b)
}
type (
	// ServerOption configures a single backend server, e.g. its weight (see
	// Weight). UpsertServer applies these options.
	ServerOption func(*server) error

	// LBOption configures a RoundRobin load balancer. New applies these
	// options in order.
	LBOption func(*RoundRobin) error
)
// server is one backend in the pool, together with the parameters that can be
// supplied through ServerOptions when it is added.
type server struct {
	// url is the backend address.
	url *url.URL

	// weight is the server's weight relative to the other endpoints in the
	// load balancer. A server with weight 0 is never picked by nextServer.
	weight int
}
// defaultWeight is the weight UpsertServer gives a newly added server whose
// options leave its weight at 0. It starts at 1; change it with
// SetDefaultWeight.
var defaultWeight = 1

// SetDefaultWeight changes the weight given to newly added servers that are not
// assigned an explicit non-zero weight. A negative weight is rejected with an
// error, and the current default is kept.
//
// NOTE(review): defaultWeight is plain package-level state with no
// synchronization. Calling this concurrently with UpsertServer is presumably
// a data race.
func SetDefaultWeight(weight int) error {
	if weight >= 0 {
		defaultWeight = weight
		return nil
	}
	return fmt.Errorf("default weight should be >= 0")
}
// sameURL reports whether a and b identify the same backend. Only scheme, host
// and path are compared; query, fragment and user info are ignored.
//
// Two nil URLs are equal, and a nil URL never matches a non-nil one. Without
// this check, pool lookups given a nil URL (e.g. RemoveServer(nil)) would
// panic with a nil-pointer dereference instead of simply finding no match.
func sameURL(a, b *url.URL) bool {
	if a == nil || b == nil {
		return a == b
	}
	return a.Path == b.Path && a.Host == b.Host && a.Scheme == b.Scheme
}
// balancerHandler is the method set RoundRobin provides: an http.Handler that
// also lets callers inspect and modify its backend pool.
type balancerHandler interface {
	http.Handler

	// Pool inspection.
	Servers() []*url.URL
	ServerWeight(u *url.URL) (int, bool)

	// Pool mutation.
	UpsertServer(u *url.URL, options ...ServerOption) error
	RemoveServer(u *url.URL) error

	// Selection and chaining.
	NextServer() (*url.URL, error)
	Next() http.Handler
}