traefik/vendor/github.com/vulcand/predicate/parse.go
2018-01-22 12:16:03 +01:00

256 lines
6 KiB
Go

package predicate
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"reflect"
"strconv"
"strings"
"github.com/gravitational/trace"
)
// NewParser builds a Parser that evaluates expressions against the
// definitions in d. The returned error is always nil in this implementation.
func NewParser(d Def) (Parser, error) {
	p := &predicateParser{d: d}
	return p, nil
}
// predicateParser is the Parser implementation returned by NewParser. It
// evaluates Go-syntax expressions by dispatching to the functions, operators
// and optional identifier/property getters found in its Def.
type predicateParser struct {
// d holds the definitions consulted while evaluating an expression.
d Def
}
// Parse reads in as a single Go expression and evaluates the resulting syntax
// tree with parseNode. A syntax error reported by go/parser is returned as is.
func (p *predicateParser) Parse(in string) (interface{}, error) {
	root, parseErr := parser.ParseExpr(in)
	if parseErr != nil {
		return nil, parseErr
	}
	return p.parseNode(root)
}
// parseNode evaluates one node of the expression tree.
//
// Supported nodes are parenthesised expressions, basic literals, function
// calls (resolved through the Def's function table) and binary expressions
// (both operands are evaluated, left first, then combined with the operator
// registered for the token). Any other node type yields a BadParameter error.
func (p *predicateParser) parseNode(node ast.Node) (interface{}, error) {
	switch expr := node.(type) {
	case *ast.ParenExpr:
		// Parentheses only group; evaluate what they wrap.
		return p.parseNode(expr.X)

	case *ast.BasicLit:
		return literalToValue(expr)

	case *ast.CallExpr:
		// The callee is expected to be a registered function that
		// returns a predicate.
		fnName, err := getIdentifier(expr.Fun)
		if err != nil {
			return nil, err
		}
		target, err := p.getFunction(fnName)
		if err != nil {
			return nil, err
		}
		callArgs, err := p.evaluateArguments(expr.Args)
		if err != nil {
			return nil, err
		}
		return callFunction(target, callArgs)

	case *ast.BinaryExpr:
		left, err := p.parseNode(expr.X)
		if err != nil {
			return nil, err
		}
		right, err := p.parseNode(expr.Y)
		if err != nil {
			return nil, err
		}
		return p.joinPredicates(expr.Op, left, right)

	default:
		return nil, trace.BadParameter("unsupported %T", node)
	}
}
// evaluateArguments evaluates every argument expression in order and returns
// the resulting values. It stops at the first expression that fails and
// returns that error wrapped.
func (p *predicateParser) evaluateArguments(nodes []ast.Expr) ([]interface{}, error) {
	values := make([]interface{}, 0, len(nodes))
	for _, node := range nodes {
		v, err := p.evaluateExpr(node)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		values = append(values, v)
	}
	return values, nil
}
// evaluateExpr turns a function-argument expression into a value.
//
//   - A basic literal is converted with literalToValue.
//   - An index expression (m[k]) evaluates both sides and delegates to the
//     Def's GetProperty; it fails with NotFound when GetProperty is unset.
//   - A selector (a.b.c) or a bare identifier is flattened into a list of
//     names and resolved through the Def's GetIdentifier; it fails with
//     NotFound when GetIdentifier is unset.
//
// Any other expression type yields a BadParameter error.
func (p *predicateParser) evaluateExpr(n ast.Expr) (interface{}, error) {
	// path is filled in by the selector/identifier cases, which share the
	// lookup that follows the switch.
	var path []string

	switch e := n.(type) {
	case *ast.BasicLit:
		lit, err := literalToValue(e)
		if err != nil {
			return nil, err
		}
		return lit, nil

	case *ast.IndexExpr:
		if p.d.GetProperty == nil {
			return nil, trace.NotFound("properties are not supported")
		}
		container, err := p.evaluateExpr(e.X)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		key, err := p.evaluateExpr(e.Index)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		prop, err := p.d.GetProperty(container, key)
		if err != nil {
			return nil, trace.Wrap(err)
		}
		return prop, nil

	case *ast.SelectorExpr:
		fields, err := evaluateSelector(e, []string{})
		if err != nil {
			return nil, trace.Wrap(err)
		}
		path = fields

	case *ast.Ident:
		path = []string{e.Name}

	default:
		return nil, trace.BadParameter("%T is not supported", n)
	}

	if p.d.GetIdentifier == nil {
		return nil, trace.NotFound("%v is not defined", strings.Join(path, "."))
	}
	resolved, err := p.d.GetIdentifier(path)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return resolved, nil
}
// evaluateSelector flattens a (possibly nested) selector expression such as
// a.b.c into its component names, outermost identifier first, and returns
// them followed by the entries already present in fields. A selector whose
// base is neither another selector nor an identifier is rejected with a
// BadParameter error.
func evaluateSelector(sel *ast.SelectorExpr, fields []string) ([]string, error) {
	// Walking down the selector chain visits names right to left, so they
	// are gathered in reverse and flipped once the base identifier is found.
	var reversed []string
	cur := sel
	for {
		reversed = append(reversed, cur.Sel.Name)
		switch base := cur.X.(type) {
		case *ast.SelectorExpr:
			cur = base
		case *ast.Ident:
			reversed = append(reversed, base.Name)
			names := make([]string, 0, len(reversed)+len(fields))
			for i := len(reversed) - 1; i >= 0; i-- {
				names = append(names, reversed[i])
			}
			return append(names, fields...), nil
		default:
			return nil, trace.BadParameter("unsupported selector type: %T", base)
		}
	}
}
// getFunction looks name up in the Def's function table and returns the
// registered function, or a BadParameter error when no such entry exists.
func (p *predicateParser) getFunction(name string) (interface{}, error) {
	if fn, found := p.d.Functions[name]; found {
		return fn, nil
	}
	return nil, trace.BadParameter("unsupported function: %s", name)
}
// joinPredicates combines the already-evaluated operands a and b by calling
// the function registered for the binary operator op. It fails when the
// operator has no registered function.
func (p *predicateParser) joinPredicates(op token.Token, a, b interface{}) (interface{}, error) {
	combine, err := p.getJoinFunction(op)
	if err != nil {
		return nil, err
	}
	operands := []interface{}{a, b}
	return callFunction(combine, operands)
}
// getJoinFunction maps a binary operator token (&&, ||, >, >=, <, <=, ==, !=)
// to the function configured for it in the Def's Operators. A BadParameter
// error is returned when the token is not one of those operators or when the
// corresponding Operators entry is nil.
func (p *predicateParser) getJoinFunction(op token.Token) (interface{}, error) {
	ops := p.d.Operators
	table := map[token.Token]interface{}{
		token.LAND: ops.AND,
		token.LOR:  ops.OR,
		token.GTR:  ops.GT,
		token.GEQ:  ops.GE,
		token.LSS:  ops.LT,
		token.LEQ:  ops.LE,
		token.EQL:  ops.EQ,
		token.NEQ:  ops.NEQ,
	}
	// A missing key and an unset operator both surface as a nil interface.
	if fn := table[op]; fn != nil {
		return fn, nil
	}
	return nil, trace.BadParameter("%v is not supported", op)
}
// getIdentifier extracts a function name from a call target. A plain
// identifier yields its name; a one-level selector (pkg.Name) yields
// "pkg.Name". Any other node, or a selector whose left-hand side is not an
// identifier, results in a BadParameter error.
func getIdentifier(node ast.Node) (string, error) {
	switch target := node.(type) {
	case *ast.SelectorExpr:
		base, isIdent := target.X.(*ast.Ident)
		if !isIdent {
			return "", trace.BadParameter("expected selector identifier, got: %T", target.X)
		}
		qualified := fmt.Sprintf("%s.%s", base.Name, target.Sel.Name)
		return qualified, nil
	case *ast.Ident:
		return target.Name, nil
	default:
		return "", trace.BadParameter("expected identifier, got: %T", node)
	}
}
// literalToValue converts a Go basic literal into a Go value:
//
//   - token.FLOAT  -> float64
//   - token.INT    -> int
//   - token.STRING -> string (unquoted)
//
// Any other literal kind, or a literal whose text cannot be converted, yields
// a BadParameter error.
func literalToValue(a *ast.BasicLit) (interface{}, error) {
	switch a.Kind {
	case token.FLOAT:
		value, err := strconv.ParseFloat(a.Value, 64)
		if err != nil {
			return nil, trace.BadParameter("failed to parse argument: %s, error: %s", a.Value, err)
		}
		return value, nil
	case token.INT:
		// Plain decimal text goes through Atoi first so that every literal
		// accepted historically keeps its value (e.g. "017" stays 17).
		value, err := strconv.Atoi(a.Value)
		if err == nil {
			return value, nil
		}
		// INT tokens may also be prefixed (0x, 0b, 0o) or contain '_'
		// separators, which Atoi rejects. Base 0 makes ParseInt follow the
		// Go integer-literal syntax; bitSize keeps the result within int.
		if wide, prefixedErr := strconv.ParseInt(a.Value, 0, strconv.IntSize); prefixedErr == nil {
			return int(wide), nil
		}
		// Report the original Atoi error to keep the message stable.
		return nil, trace.BadParameter("failed to parse argument: %s, error: %s", a.Value, err)
	case token.STRING:
		value, err := strconv.Unquote(a.Value)
		if err != nil {
			return nil, trace.BadParameter("failed to parse argument: %s, error: %s", a.Value, err)
		}
		return value, nil
	}
	return nil, trace.BadParameter("unsupported function argument type: '%v'", a.Kind)
}
// callFunction invokes f through reflection with args as its arguments.
//
// f is expected to be a function returning either a single value or a
// (value, error) pair:
//
//   - one result: it is returned with a nil error;
//   - two results: the first is returned together with the second as the
//     error; a non-nil second result that does not implement error is
//     rejected with a BadParameter error;
//   - any other number of results is rejected with a BadParameter error.
//
// Any panic raised while setting up or performing the call (f is not a
// function, wrong argument count or types, a panic inside f itself) is
// recovered and reported as a BadParameter error.
func callFunction(f interface{}, args []interface{}) (v interface{}, err error) {
	defer func() {
		if r := recover(); r != nil {
			// The panic value can be of any type, so format it with %v;
			// %s would render non-string, non-error values as "%!s(...)".
			err = trace.BadParameter("%v", r)
		}
	}()

	in := make([]reflect.Value, len(args))
	for i, a := range args {
		in[i] = reflect.ValueOf(a)
	}

	fn := reflect.ValueOf(f)
	out := fn.Call(in)

	switch len(out) {
	case 1:
		return out[0].Interface(), nil
	case 2:
		result, second := out[0].Interface(), out[1].Interface()
		if second == nil {
			return result, nil
		}
		callErr, ok := second.(error)
		if !ok {
			return nil, trace.BadParameter("expected error as a second return value, got %T", second)
		}
		return result, callErr
	}
	return nil, trace.BadParameter("expected at least one return argument for '%v'", fn)
}