227 lines
5.2 KiB
Go
227 lines
5.2 KiB
Go
package stream
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/vulcand/oxy/utils"
|
|
"github.com/vulcand/predicate"
|
|
)
|
|
|
|
func IsValidExpression(expr string) bool {
|
|
_, err := parseExpression(expr)
|
|
return err == nil
|
|
}
|
|
|
|
type context struct {
|
|
r *http.Request
|
|
attempt int
|
|
responseCode int
|
|
log utils.Logger
|
|
}
|
|
|
|
// hpredicate is a compiled expression: it reports whether its condition holds
// for the given evaluation context.
type hpredicate func(*context) bool
|
|
|
|
// Parses expression in the go language into Failover predicates
|
|
func parseExpression(in string) (hpredicate, error) {
|
|
p, err := predicate.NewParser(predicate.Def{
|
|
Operators: predicate.Operators{
|
|
AND: and,
|
|
OR: or,
|
|
EQ: eq,
|
|
NEQ: neq,
|
|
LT: lt,
|
|
GT: gt,
|
|
LE: le,
|
|
GE: ge,
|
|
},
|
|
Functions: map[string]interface{}{
|
|
"RequestMethod": requestMethod,
|
|
"IsNetworkError": isNetworkError,
|
|
"Attempts": attempts,
|
|
"ResponseCode": responseCode,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out, err := p.Parse(in)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pr, ok := out.(hpredicate)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected predicate, got %T", out)
|
|
}
|
|
return pr, nil
|
|
}
|
|
|
|
type toString func(c *context) string
|
|
type toInt func(c *context) int
|
|
|
|
// RequestMethod returns mapper of the request to its method e.g. POST
|
|
func requestMethod() toString {
|
|
return func(c *context) string {
|
|
return c.r.Method
|
|
}
|
|
}
|
|
|
|
// Attempts returns mapper of the request to the number of proxy attempts
|
|
func attempts() toInt {
|
|
return func(c *context) int {
|
|
return c.attempt
|
|
}
|
|
}
|
|
|
|
// ResponseCode returns mapper of the request to the last response code, returns 0 if there was no response code.
|
|
func responseCode() toInt {
|
|
return func(c *context) int {
|
|
return c.responseCode
|
|
}
|
|
}
|
|
|
|
// IsNetworkError returns a predicate that returns true if last attempt ended with network error.
|
|
func isNetworkError() hpredicate {
|
|
return func(c *context) bool {
|
|
return c.responseCode == http.StatusBadGateway || c.responseCode == http.StatusGatewayTimeout
|
|
}
|
|
}
|
|
|
|
// and returns predicate by joining the passed predicates with logical 'and'
|
|
func and(fns ...hpredicate) hpredicate {
|
|
return func(c *context) bool {
|
|
for _, fn := range fns {
|
|
if !fn(c) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
}
|
|
|
|
// or returns predicate by joining the passed predicates with logical 'or'
|
|
func or(fns ...hpredicate) hpredicate {
|
|
return func(c *context) bool {
|
|
for _, fn := range fns {
|
|
if fn(c) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
}
|
|
|
|
// not creates negation of the passed predicate
|
|
func not(p hpredicate) hpredicate {
|
|
return func(c *context) bool {
|
|
return !p(c)
|
|
}
|
|
}
|
|
|
|
// eq returns predicate that tests for equality of the value of the mapper and the constant
|
|
func eq(m interface{}, value interface{}) (hpredicate, error) {
|
|
switch mapper := m.(type) {
|
|
case toString:
|
|
return stringEQ(mapper, value)
|
|
case toInt:
|
|
return intEQ(mapper, value)
|
|
}
|
|
return nil, fmt.Errorf("unsupported argument: %T", m)
|
|
}
|
|
|
|
// neq returns predicate that tests for inequality of the value of the mapper and the constant
|
|
func neq(m interface{}, value interface{}) (hpredicate, error) {
|
|
p, err := eq(m, value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return not(p), nil
|
|
}
|
|
|
|
// lt returns predicate that tests that value of the mapper function is less than the constant
|
|
func lt(m interface{}, value interface{}) (hpredicate, error) {
|
|
switch mapper := m.(type) {
|
|
case toInt:
|
|
return intLT(mapper, value)
|
|
}
|
|
return nil, fmt.Errorf("unsupported argument: %T", m)
|
|
}
|
|
|
|
// le returns predicate that tests that value of the mapper function is less or equal than the constant
|
|
func le(m interface{}, value interface{}) (hpredicate, error) {
|
|
l, err := lt(m, value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
e, err := eq(m, value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return func(c *context) bool {
|
|
return l(c) || e(c)
|
|
}, nil
|
|
}
|
|
|
|
// gt returns predicate that tests that value of the mapper function is greater than the constant
|
|
func gt(m interface{}, value interface{}) (hpredicate, error) {
|
|
switch mapper := m.(type) {
|
|
case toInt:
|
|
return intGT(mapper, value)
|
|
}
|
|
return nil, fmt.Errorf("unsupported argument: %T", m)
|
|
}
|
|
|
|
// ge returns predicate that tests that value of the mapper function is less or equal than the constant
|
|
func ge(m interface{}, value interface{}) (hpredicate, error) {
|
|
g, err := gt(m, value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
e, err := eq(m, value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return func(c *context) bool {
|
|
return g(c) || e(c)
|
|
}, nil
|
|
}
|
|
|
|
func stringEQ(m toString, val interface{}) (hpredicate, error) {
|
|
value, ok := val.(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected string, got %T", val)
|
|
}
|
|
return func(c *context) bool {
|
|
return m(c) == value
|
|
}, nil
|
|
}
|
|
|
|
func intEQ(m toInt, val interface{}) (hpredicate, error) {
|
|
value, ok := val.(int)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected int, got %T", val)
|
|
}
|
|
return func(c *context) bool {
|
|
return m(c) == value
|
|
}, nil
|
|
}
|
|
|
|
func intLT(m toInt, val interface{}) (hpredicate, error) {
|
|
value, ok := val.(int)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected int, got %T", val)
|
|
}
|
|
return func(c *context) bool {
|
|
return m(c) < value
|
|
}, nil
|
|
}
|
|
|
|
func intGT(m toInt, val interface{}) (hpredicate, error) {
|
|
value, ok := val.(int)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected int, got %T", val)
|
|
}
|
|
return func(c *context) bool {
|
|
return m(c) > value
|
|
}, nil
|
|
}
|