Select Git revision
mongo.go 15.18 KiB
package expr
import (
"context"
"fmt"
"regexp"
"strings"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast"
exprcompiler "github.com/expr-lang/expr/compiler"
"github.com/expr-lang/expr/conf"
"github.com/expr-lang/expr/parser"
"go.mongodb.org/mongo-driver/bson"
)
// geoTypes maps the shape names accepted as the second argument of the
// within() filter function to the MongoDB shape operator placed inside
// $geoWithin.
var geoTypes = map[string]string{
	"polygon": "$polygon",
	"box":     "$box",
}
// ConvertToMongo parses the expr-lang expression exp and translates it into a
// MongoDB filter document. An empty expression yields an empty filter.
//
// env supplies the values used when evaluating the value side of operators,
// identifierRenameFn (may be nil) rewrites identifier names before they are
// used as field names, and ops are applied to the expr configuration.
// Parse failures are returned unchanged.
func ConvertToMongo(ctx context.Context, exp string, env map[string]interface{}, identifierRenameFn func(string) string, ops ...expr.Option) (b bson.M, err error) {
	if len(exp) == 0 {
		return bson.M{}, nil
	}
	parsed, parseErr := parser.Parse(exp)
	if parseErr != nil {
		return nil, parseErr
	}
	return convertToMongo(ctx, parsed, env, identifierRenameFn, ops...)
}
// convertToMongo translates an already-parsed expression tree into a MongoDB
// filter document.
//
// The compiler reports unsupported constructs by panicking; the deferred
// recover turns any such panic into the returned error (error values are kept
// as-is so callers can inspect them with errors.Is/As).
//
// The caller's env map is copied before ctx is stored under EnvContextKey, so
// this function never writes to the caller's map. The original version did,
// which leaked EnvContextKey into the caller's map and raced when one env map
// was shared by concurrent conversions.
//
// Visitors from the configuration are applied to tree in place before
// translation.
func convertToMongo(ctx context.Context, tree *parser.Tree, env map[string]interface{}, identifierRenameFn func(string) string, ops ...expr.Option) (b bson.M, err error) {
	defer func() {
		if r := recover(); r != nil {
			b = nil
			if e, ok := r.(error); ok {
				err = e
				return
			}
			err = fmt.Errorf("%v", r)
		}
	}()

	// Work on a private copy of the environment (nil env is fine to range over).
	localEnv := make(map[string]interface{}, len(env)+1)
	for k, v := range env {
		localEnv[k] = v
	}
	localEnv[EnvContextKey] = ctx

	config := GetDefaultConfig(localEnv)
	for _, op := range ops {
		op(config)
	}

	// Options may replace config.Env; the compiler only supports map
	// environments, so report anything else as an error instead of panicking
	// on an unchecked type assertion.
	finalEnv, ok := config.Env.(map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("unsupported expr environment type %T, expected map[string]interface{}", config.Env)
	}

	for _, v := range config.Visitors {
		ast.Walk(&tree.Node, v)
	}

	c := &compiler{tree: tree, env: finalEnv, config: config, identifierRenameFn: identifierRenameFn}
	filter, ok := c.compile(tree.Node).(bson.M)
	if !ok || filter == nil {
		return nil, fmt.Errorf("invalid expression")
	}
	return filter, nil
}
// compiler walks an expr AST and produces the equivalent MongoDB filter.
type compiler struct {
	// env holds the values visible when eval runs value-side sub-expressions.
	env map[string]interface{}
	// tree is the parsed expression; eval reuses its Source for sub-trees.
	tree *parser.Tree
	// config is the expr configuration handed to the expr compiler in eval.
	config *conf.Config
	// identifierRenameFn, when non-nil, rewrites identifier names before they
	// are used as field names.
	identifierRenameFn func(string) string
}
// eval compiles node as a standalone expr program and runs it against c.env,
// returning the resulting Go value. It is used for the value side of filter
// operators. Compilation and execution failures are reported by panicking.
func (c *compiler) eval(node ast.Node) interface{} {
	program, compileErr := exprcompiler.Compile(&parser.Tree{Node: node, Source: c.tree.Source}, c.config)
	if compileErr != nil {
		panic(fmt.Sprintf("compile error %s", compileErr.Error()))
	}
	result, runErr := expr.Run(program, c.env)
	if runErr != nil {
		panic(fmt.Sprintf("execution error %s", runErr.Error()))
	}
	return result
}
// compile dispatches node to the handler for its concrete AST type and returns
// that handler's result: a bson.M for filter-producing nodes, a field path
// string for identifiers and member accesses, or a literal Go value.
// Node types without a handler panic.
func (c *compiler) compile(node ast.Node) interface{} {
	switch nd := node.(type) {
	// Literals.
	case *ast.NilNode:
		return c.NilNode(nd)
	case *ast.IntegerNode:
		return c.IntegerNode(nd)
	case *ast.FloatNode:
		return c.FloatNode(nd)
	case *ast.BoolNode:
		return c.BoolNode(nd)
	case *ast.StringNode:
		return c.StringNode(nd)
	case *ast.ConstantNode:
		return c.ConstantNode(nd)

	// Field references.
	case *ast.IdentifierNode:
		return c.IdentifierNode(nd)
	case *ast.MemberNode:
		return c.MemberNode(nd)
	case *ast.ChainNode:
		return c.ChainNode(nd)
	case *ast.SliceNode:
		return c.SliceNode(nd)

	// Operators, calls and closures.
	case *ast.UnaryNode:
		return c.UnaryNode(nd)
	case *ast.BinaryNode:
		return c.BinaryNode(nd)
	case *ast.CallNode:
		return c.CallNode(nd)
	case *ast.BuiltinNode:
		return c.BuiltinNode(nd)
	case *ast.ClosureNode:
		return c.ClosureNode(nd)
	case *ast.PointerNode:
		return c.PointerNode(nd)
	case *ast.ConditionalNode:
		return c.ConditionalNode(nd)
	case *ast.VariableDeclaratorNode:
		return c.VariableDeclaratorNode(nd)

	// Composite literals.
	case *ast.ArrayNode:
		return c.ArrayNode(nd)
	case *ast.MapNode:
		return c.MapNode(nd)
	case *ast.PairNode:
		return c.PairNode(nd)
	}
	panic(fmt.Sprintf("undefined node type (%T)", node))
}
// NilNode translates the nil literal into a Go nil value.
func (c *compiler) NilNode(_ *ast.NilNode) interface{} {
	var none interface{}
	return none
}
// IdentifierNode returns the identifier's name for use as a field name,
// passed through c.identifierRenameFn when one is configured.
func (c *compiler) IdentifierNode(node *ast.IdentifierNode) string {
	if rename := c.identifierRenameFn; rename != nil {
		return rename(node.Value)
	}
	return node.Value
}
// IntegerNode returns the integer literal's value as a Go int. No conversion
// to a narrower or floating-point kind based on the node's type is attempted.
func (c *compiler) IntegerNode(node *ast.IntegerNode) int {
	value := node.Value
	return value
}
// FloatNode returns the float literal's value unchanged.
func (c *compiler) FloatNode(node *ast.FloatNode) float64 {
	f := node.Value
	return f
}
// BoolNode returns the boolean literal's value unchanged.
func (c *compiler) BoolNode(node *ast.BoolNode) bool {
	flag := node.Value
	return flag
}
// StringNode returns the string literal's value unchanged. It is also how
// a member-access property such as the "b" in `a.b` is turned into a path
// segment.
func (c *compiler) StringNode(node *ast.StringNode) string {
	s := node.Value
	return s
}
// ConstantNode returns the constant node's stored value unchanged.
func (c *compiler) ConstantNode(node *ast.ConstantNode) interface{} {
	constant := node.Value
	return constant
}
// UnaryNode translates logical negation ("!" / "not").
//
// A negated membership test `not (x in list)` becomes {x: {$nin: list}}.
// Any other negated expression becomes {$nor: [expr]}. MongoDB's $not is a
// field-level operator and is rejected at the top level of a filter
// ("unknown top level operator: $not"), which the previous {$not: expr}
// output produced. $nor with a single clause accepts an arbitrary filter
// document and matches exactly the documents that clause does not.
// Any other unary operator panics.
func (c *compiler) UnaryNode(node *ast.UnaryNode) interface{} {
	switch node.Operator {
	case "!", "not":
		if inner, ok := node.Node.(*ast.BinaryNode); ok && inner.Operator == "in" {
			return bson.M{c.identifier(inner.Left): bson.M{"$nin": c.eval(inner.Right)}}
		}
		return bson.M{"$nor": bson.A{c.compile(node.Node)}}
	default:
		panic(fmt.Sprintf("unknown operator (%v)", node.Operator))
	}
}
// identifier resolves node to a MongoDB field path. Only identifiers and
// member accesses qualify; any other node panics with a dump of the node.
func (c *compiler) identifier(node ast.Node) string {
	if member, ok := node.(*ast.MemberNode); ok {
		return c.MemberNode(member)
	}
	if ident, ok := node.(*ast.IdentifierNode); ok {
		return c.IdentifierNode(ident)
	}
	panic(fmt.Sprintf("incorrect identifier node (%v) ", ast.Dump(node)))
}
// BinaryNode translates a binary expression into a MongoDB filter document.
//
// For the logical operators (or/||, and/&&) both operands are translated
// recursively into $or / $and clauses. For every other supported operator the
// left operand must be a field reference (see identifier) and the right
// operand is evaluated with eval. contains/startsWith/endsWith require a
// string operand, which is regex-escaped before being used in $regex.
// Arithmetic operators (+, -, *, /, %, **) and ranges are not supported and
// panic.
func (c *compiler) BinaryNode(node *ast.BinaryNode) interface{} {
	// fieldOp builds {<left field>: {<mongoOp>: <evaluated right operand>}}.
	fieldOp := func(mongoOp string) bson.M {
		field := c.identifier(node.Left)
		return bson.M{field: bson.M{mongoOp: c.eval(node.Right)}}
	}
	// quotedOperand evaluates the right operand, requires a string, and
	// escapes it so it matches literally inside a $regex.
	quotedOperand := func(opName string) string {
		s, isString := c.eval(node.Right).(string)
		if !isString {
			panic(opName + " requires string as an argument")
		}
		return regexp.QuoteMeta(s)
	}

	switch op := node.Operator; op {
	case "or", "||":
		return bson.M{"$or": bson.A{c.compile(node.Left), c.compile(node.Right)}}
	case "and", "&&":
		return bson.M{"$and": bson.A{c.compile(node.Left), c.compile(node.Right)}}
	case "==":
		return bson.M{c.identifier(node.Left): c.eval(node.Right)}
	case "!=":
		return fieldOp("$ne")
	case "in":
		return fieldOp("$in")
	case "not in":
		return fieldOp("$nin")
	case "<":
		return fieldOp("$lt")
	case ">":
		return fieldOp("$gt")
	case "<=":
		return fieldOp("$lte")
	case ">=":
		return fieldOp("$gte")
	case "contains":
		pattern := quotedOperand(op)
		return bson.M{c.identifier(node.Left): bson.M{"$regex": pattern}}
	case "startsWith":
		pattern := "^" + quotedOperand(op) + ".*"
		return bson.M{c.identifier(node.Left): bson.M{"$regex": pattern}}
	case "endsWith":
		pattern := ".*" + quotedOperand(op) + "$"
		return bson.M{c.identifier(node.Left): bson.M{"$regex": pattern}}
	case "..":
		panic("unsupported range")
	}
	panic(fmt.Sprintf("unknown operator (%v)", node.Operator))
}
// ChainNode is not supported by the Mongo translator and always panics.
func (c *compiler) ChainNode(_ *ast.ChainNode) string {
	const msg = "unsupported chain node"
	panic(msg)
}
// MemberNode flattens a member access into a dotted MongoDB field path,
// e.g. `a.b` -> "a.b" and, for an integer property such as an index, "a.0".
// The base must itself translate to a path string (an identifier or another
// member access); otherwise it panics.
//
// The property is formatted with %v rather than %s. IntegerNode yields an int,
// and %s would render it as "a.%!s(int=0)" instead of "a.0".
func (c *compiler) MemberNode(node *ast.MemberNode) string {
	base, ok := c.compile(node.Node).(string)
	if !ok {
		panic(fmt.Sprintf("unsupported property for %v", ast.Dump(node.Node)))
	}
	return fmt.Sprintf("%s.%v", base, c.compile(node.Property))
}
// SliceNode is not supported by the Mongo translator and always panics.
func (c *compiler) SliceNode(_ *ast.SliceNode) interface{} {
	const msg = "unsupported slice node"
	panic(msg)
}
// CallNode translates the supported filter functions into MongoDB operators:
//
//   - search(text) / q(text): {$text: {$search: text}}
//   - near(field, point, maxDistance):
//     {field.geometry: {$near: {$geometry: Point, $maxDistance: maxDistance}}}
//   - within(field, shape, points): {field.geometry: {$geoWithin: {op: points}}},
//     where shape is a key of geoTypes ("box" or "polygon")
//   - In(field, values): {field: {$in: values}}; a non-array value is wrapped
//     in a one-element array
//   - icontains / istartsWith / iendsWith(field, s): case-insensitive $regex
//     on the regex-escaped string s
//
// For the geo functions ".geometry" is appended to the field path unless it
// already ends with it. Calls with too few arguments, invalid arguments, or an
// unknown function name panic.
func (c *compiler) CallNode(node *ast.CallNode) interface{} {
	name := node.Callee.String()
	args := node.Arguments

	// requireArgs reports a descriptive error instead of an index-out-of-range
	// runtime panic when a supported function gets too few arguments.
	requireArgs := func(n int) {
		if len(args) < n {
			panic(fmt.Sprintf("%s requires %d arguments, got %d", name, n, len(args)))
		}
	}
	// geometryPath rejects an empty field name and appends ".geometry" unless
	// the path already ends with it.
	geometryPath := func(field string) string {
		if field == "" {
			panic("incorrect argument, empty field name")
		}
		if !strings.HasSuffix(field, ".geometry") {
			field += ".geometry"
		}
		return field
	}
	// caseInsensitiveRegex builds {field: {$regex: prefix+escaped+suffix, $options: "i"}}
	// from a field reference and a string second argument.
	caseInsensitiveRegex := func(prefix, suffix string) bson.M {
		requireArgs(2)
		field := c.identifier(args[0])
		s, ok := c.eval(args[1]).(string)
		if !ok {
			panic(name + " requires string as an argument")
		}
		return bson.M{field: bson.M{"$regex": prefix + regexp.QuoteMeta(s) + suffix, "$options": "i"}}
	}

	switch name {
	case "search", "q":
		requireArgs(1)
		// NOTE(review): the argument goes through compile, not eval, so an
		// identifier argument yields its (renamed) name rather than its env
		// value; other functions evaluate their value arguments. Confirm intent.
		val := c.compile(args[0])
		return bson.M{"$text": bson.M{"$search": val}}
	case "near":
		requireArgs(3)
		field := c.identifier(args[0])
		point := c.eval(args[1])
		distance := c.eval(args[2])
		field = geometryPath(field)
		if _, ok := point.([]interface{}); !ok {
			panic("incorrect argument, point must be coordinates array")
		}
		return bson.M{
			field: bson.M{"$near": bson.D{
				{Key: "$geometry", Value: map[string]interface{}{"type": "Point", "coordinates": point}},
				{Key: "$maxDistance", Value: distance},
			}},
		}
	case "within":
		requireArgs(3)
		field := c.identifier(args[0])
		shape := c.eval(args[1])
		points := c.eval(args[2])
		field = geometryPath(field)
		shapeName, ok := shape.(string)
		if !ok {
			panic("incorrect argument, geotype must be string")
		}
		operator, ok := geoTypes[shapeName]
		if !ok {
			panic("incorrect geotype value")
		}
		if _, ok := points.([]interface{}); !ok {
			panic("incorrect argument, points must be array of coordinates")
		}
		return bson.M{field: bson.M{"$geoWithin": bson.M{operator: points}}}
	case "In":
		requireArgs(2)
		field := c.identifier(args[0])
		if field == "" {
			panic("incorrect argument, empty field name")
		}
		// Evaluate once; the value is reused when it has to be wrapped.
		value := c.eval(args[1])
		array, ok := value.([]interface{})
		if !ok {
			array = []interface{}{value}
		}
		return bson.M{field: bson.M{"$in": array}}
	case "icontains":
		return caseInsensitiveRegex("", "")
	case "istartsWith":
		return caseInsensitiveRegex("^", ".*")
	case "iendsWith":
		return caseInsensitiveRegex(".*", "$")
	}
	panic("unsupported function")
}
// BuiltinNode rejects calls to expr builtins (e.g. len, all, any, filter,
// map, count): none of them are translated into MongoDB filters, so this
// always panics. The panic message names the offending builtin. This also
// fixes the misspelled "builin" in the previous message.
func (c *compiler) BuiltinNode(node *ast.BuiltinNode) interface{} {
	panic(fmt.Sprintf("unsupported builtin node (%v)", node.Name))
}
//func (c *compiler) emitLoop(body func()) []byte {
// i := c.makeConstant("i")
// size := c.makeConstant("size")
// array := c.makeConstant("array")
//
// c.emit(OpLen)
// c.emit(OpStore, size...)
// c.emit(OpStore, array...)
// c.emitPush(0)
// c.emit(OpStore, i...)
//
// cond := len(c.bytecode)
// c.emit(OpLoad, i...)
// c.emit(OpLoad, size...)
// c.emit(OpLess)
// end := c.emit(OpJumpIfFalse, c.placeholder()...)
// c.emit(OpPop)
//
// body()
//
// c.emit(OpInc, i...)
// c.emit(OpJumpBackward, c.calcBackwardJump(cond)...)
//
// c.patchJump(end)
// c.emit(OpPop)
//
// return size
//}
// ClosureNode translates a closure by translating its body expression.
func (c *compiler) ClosureNode(node *ast.ClosureNode) interface{} {
	body := node.Node
	return c.compile(body)
}
// PointerNode is not supported by the Mongo translator and always panics.
func (c *compiler) PointerNode(_ *ast.PointerNode) interface{} {
	const msg = "unsupported pointer node"
	panic(msg)
}
// ConditionalNode rejects ternary (`cond ? a : b`) expressions; they have no
// translation into a MongoDB filter here, so this always panics.
func (c *compiler) ConditionalNode(_ *ast.ConditionalNode) interface{} {
	const msg = "unsupported conditional node"
	panic(msg)
}
// VariableDeclaratorNode is not supported by the Mongo translator and always
// panics. The int result type is never produced.
func (c *compiler) VariableDeclaratorNode(_ *ast.VariableDeclaratorNode) int {
	const msg = "unsupported variable declarator node "
	panic(msg)
}
// ArrayNode panics when an array literal reaches the translator directly.
// Array literals on the value side of an operator are handled by eval, which
// does not go through compile.
func (c *compiler) ArrayNode(_ *ast.ArrayNode) interface{} {
	const msg = "unsupported array node"
	panic(msg)
}
// MapNode panics when a map literal reaches the translator directly.
// Value-side map literals are handled by eval, which does not go through
// compile.
func (c *compiler) MapNode(_ *ast.MapNode) interface{} {
	const msg = "unsupported map node"
	panic(msg)
}
// PairNode (a key/value entry of a map literal) is not supported by the Mongo
// translator and always panics.
func (c *compiler) PairNode(_ *ast.PairNode) interface{} {
	const msg = "unsupported pair node"
	panic(msg)
}