package expr

import (
	"context"
	"fmt"
	"reflect"
	"regexp"
	"strings"

	"github.com/antonmedv/expr"
	"github.com/antonmedv/expr/ast"
	compiler2 "github.com/antonmedv/expr/compiler"
	"github.com/antonmedv/expr/conf"
	"github.com/antonmedv/expr/parser"
	"go.mongodb.org/mongo-driver/bson"
)

// geoTypes maps the shape name accepted as the second argument of within()
// to the $geoWithin shape operator emitted for it.
var geoTypes = map[string]string{
	"polygon": "$polygon",
	"box":     "$box",
}

// ConvertToMongo parses exp with the expr parser and translates it into a
// MongoDB filter document.
//
// An empty exp yields an empty bson.M without parsing. env supplies values
// used when evaluating operands, identifierRenameFn (optional) maps
// expression identifiers to field names, and ops are applied to the expr
// configuration. Parse errors are returned unchanged.
func ConvertToMongo(ctx context.Context, exp string, env map[string]interface{}, identifierRenameFn func(string) string, ops ...expr.Option) (b bson.M, err error) {
	if exp != "" {
		parsed, parseErr := parser.Parse(exp)
		if parseErr != nil {
			return nil, parseErr
		}
		return convertToMongo(ctx, parsed, env, identifierRenameFn, ops...)
	}
	return bson.M{}, nil
}

// convertToMongo translates an already-parsed expression tree into a MongoDB
// filter document.
//
// The caller's env is copied, and the copy receives ctx under EnvContextKey.
// It is handed to GetDefaultConfig, the result is adjusted by ops, and any
// configured AST visitors are run over tree before translation.
//
// A top-level predicate yields its bson.M directly. A top-level string (for
// example a bare word) becomes a $text search. Any other top-level value is
// reported as "invalid expression". Panics raised during translation
// (unsupported nodes, evaluation failures, bad arguments) are recovered and
// returned as err.
func convertToMongo(ctx context.Context, tree *parser.Tree, env map[string]interface{}, identifierRenameFn func(string) string, ops ...expr.Option) (b bson.M, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("%v", r)
		}
	}()

	// Work on a private copy instead of writing into the caller's map. A
	// shared env reused by concurrent conversions would otherwise trigger
	// "concurrent map writes", a fatal runtime error that recover cannot catch.
	scoped := make(map[string]interface{}, len(env)+1)
	for k, v := range env {
		scoped[k] = v
	}
	scoped[EnvContextKey] = ctx

	config := GetDefaultConfig(scoped)
	for _, op := range ops {
		op(config)
	}

	// Options may have replaced the environment; always use the configured one.
	env = config.Env.(map[string]interface{})

	// ast.Walk receives &tree.Node, so visitors are able to replace nodes
	// before translation.
	for _, visitor := range config.Visitors {
		ast.Walk(&tree.Node, visitor)
	}

	c := &compiler{tree: tree, env: env, config: config, identifierRenameFn: identifierRenameFn}
	switch e := c.compile(tree.Node).(type) {
	case bson.M:
		b = e
	case string:
		b = bson.M{"$text": bson.M{"$search": e}}
	default:
		err = fmt.Errorf("invalid expression")
	}
	return b, err
}

// compiler holds the state used to translate one parsed expression tree into
// a MongoDB filter document.
type compiler struct {
	// env is the variable environment passed to expr.Run when operands are
	// evaluated (see eval).
	env                map[string]interface{}
	// tree is the full parsed expression; its Source is reused for the
	// sub-trees compiled in eval.
	tree               *parser.Tree
	// config is the expr configuration handed to the bytecode compiler in eval.
	config             *conf.Config
	// identifierRenameFn, when non-nil, maps each expression identifier to the
	// field name emitted in the filter.
	identifierRenameFn func(string) string
}

// eval compiles the given sub-tree with the stock expr compiler and runs it
// against c.env, returning the resulting Go value. It produces the value side
// of comparisons and function arguments.
//
// Compilation and execution failures panic with "compile error ..." /
// "execution error ..."; convertToMongo recovers these into an error.
func (c *compiler) eval(node ast.Node) interface{} {
	sub := &parser.Tree{Node: node, Source: c.tree.Source}

	program, compileErr := compiler2.Compile(sub, c.config)
	if compileErr != nil {
		panic("compile error " + compileErr.Error())
	}

	result, runErr := expr.Run(program, c.env)
	if runErr != nil {
		panic("execution error " + runErr.Error())
	}
	return result
}

// compile translates a single AST node by dispatching on its concrete type.
//
// Predicates come back as bson.M filter fragments. Literals come back as
// their Go values, and field references come back as path strings. Node
// kinds without a translator panic with the node's dynamic type.
func (c *compiler) compile(node ast.Node) interface{} {
	switch n := node.(type) {
	// Literal values.
	case *ast.NilNode:
		return c.NilNode(n)
	case *ast.IntegerNode:
		return c.IntegerNode(n)
	case *ast.FloatNode:
		return c.FloatNode(n)
	case *ast.BoolNode:
		return c.BoolNode(n)
	case *ast.StringNode:
		return c.StringNode(n)
	case *ast.ConstantNode:
		return c.ConstantNode(n)

	// Field references and member access.
	case *ast.IdentifierNode:
		return c.IdentifierNode(n)
	case *ast.PropertyNode:
		return c.PropertyNode(n)
	case *ast.IndexNode:
		return c.IndexNode(n)
	case *ast.SliceNode:
		return c.SliceNode(n)
	case *ast.PointerNode:
		return c.PointerNode(n)

	// Operators.
	case *ast.UnaryNode:
		return c.UnaryNode(n)
	case *ast.BinaryNode:
		return c.BinaryNode(n)
	case *ast.MatchesNode:
		return c.MatchesNode(n)
	case *ast.ConditionalNode:
		return c.ConditionalNode(n)

	// Calls and closures.
	case *ast.MethodNode:
		return c.MethodNode(n)
	case *ast.FunctionNode:
		return c.FunctionNode(n)
	case *ast.BuiltinNode:
		return c.BuiltinNode(n)
	case *ast.ClosureNode:
		return c.ClosureNode(n)

	// Composite literals.
	case *ast.ArrayNode:
		return c.ArrayNode(n)
	case *ast.MapNode:
		return c.MapNode(n)
	case *ast.PairNode:
		return c.PairNode(n)
	}
	panic(fmt.Sprintf("undefined node type (%T)", node))
}

// NilNode translates the nil literal into an untyped nil interface value.
func (c *compiler) NilNode(node *ast.NilNode) interface{} {
	var null interface{}
	return null
}

// IdentifierNode returns the field name for a bare identifier. When
// identifierRenameFn is set, its result is returned instead of the raw
// identifier.
func (c *compiler) IdentifierNode(node *ast.IdentifierNode) string {
	if rename := c.identifierRenameFn; rename != nil {
		return rename(node.Value)
	}
	return node.Value
}

// IntegerNode returns the integer literal's value as-is. No conversion to a
// narrower or wider numeric type is performed here.
func (c *compiler) IntegerNode(node *ast.IntegerNode) int {
	literal := node.Value
	return literal
}

// FloatNode returns the floating-point literal's value unchanged.
func (c *compiler) FloatNode(node *ast.FloatNode) float64 {
	literal := node.Value
	return literal
}

// BoolNode returns the boolean literal's value unchanged.
func (c *compiler) BoolNode(node *ast.BoolNode) bool {
	literal := node.Value
	return literal
}

// StringNode returns the string literal's value unchanged. At the top level
// of an expression, convertToMongo turns such a string into a $text search.
func (c *compiler) StringNode(node *ast.StringNode) string {
	literal := node.Value
	return literal
}

// ConstantNode returns the value stored in a constant node unchanged.
func (c *compiler) ConstantNode(node *ast.ConstantNode) interface{} {
	constant := node.Value
	return constant
}

// UnaryNode translates logical negation ("!" or "not") of a sub-expression
// into a MongoDB filter.
//
// MongoDB's $not only works at field level ({field: {$not: {...}}}) and is
// rejected as a top-level query operator. Negating a whole sub-filter is
// therefore expressed as {$nor: [filter]}, which matches exactly the
// documents the filter does not match. Any other unary operator panics.
func (c *compiler) UnaryNode(node *ast.UnaryNode) interface{} {
	switch node.Operator {
	case "!", "not":
		return bson.M{"$nor": bson.A{c.compile(node.Node)}}
	default:
		panic(fmt.Sprintf("unknown operator (%v)", node.Operator))
	}
}

// identifier resolves a node that must denote a document field. It accepts a
// bare identifier or a property chain (for example a.b). Any other node kind
// panics with a dump of the offending sub-tree.
func (c *compiler) identifier(node ast.Node) string {
	if prop, isProp := node.(*ast.PropertyNode); isProp {
		return c.PropertyNode(prop)
	}
	if ident, isIdent := node.(*ast.IdentifierNode); isIdent {
		return c.IdentifierNode(ident)
	}
	panic(fmt.Sprintf("incorrect identifier node (%v) ", ast.Dump(node)))
}

// BinaryNode translates a binary expression into a MongoDB filter fragment.
//
//   - "==" becomes an equality match {field: value}.
//   - "!=", "in", "not in", "<", ">", "<=", ">=" become {field: {$op: value}},
//     using $ne, $in, $nin, $lt, $gt, $lte, $gte respectively.
//   - "and"/"&&" and "or"/"||" combine both translated sides under $and / $or.
//   - "contains", "startsWith", "endsWith" become regex-escaped $regex matches;
//     the right operand must evaluate to a string.
//
// The left operand must name a field (see identifier). The right operand is
// evaluated against the environment. Arithmetic operators and ".." are not
// translated and panic.
func (c *compiler) BinaryNode(node *ast.BinaryNode) interface{} {
	operator := node.Operator

	// Field comparisons: operator maps 1:1 onto a MongoDB query operator.
	mongoOp := ""
	switch operator {
	case "!=":
		mongoOp = "$ne"
	case "in":
		mongoOp = "$in"
	case "not in":
		mongoOp = "$nin"
	case "<":
		mongoOp = "$lt"
	case ">":
		mongoOp = "$gt"
	case "<=":
		mongoOp = "$lte"
	case ">=":
		mongoOp = "$gte"
	}
	if mongoOp != "" {
		field := c.identifier(node.Left)
		value := c.eval(node.Right)
		return bson.M{field: bson.M{mongoOp: value}}
	}

	switch operator {
	case "==":
		field := c.identifier(node.Left)
		value := c.eval(node.Right)
		return bson.M{field: value}

	case "or", "||", "and", "&&":
		combinator := "$and"
		if operator == "or" || operator == "||" {
			combinator = "$or"
		}
		left := c.compile(node.Left)
		right := c.compile(node.Right)
		return bson.M{combinator: bson.A{left, right}}

	case "contains", "startsWith", "endsWith":
		// The operand is evaluated and type-checked before the field is resolved.
		text, isString := c.eval(node.Right).(string)
		if !isString {
			panic(operator + " requires string as an argument")
		}
		pattern := regexp.QuoteMeta(text)
		if operator == "startsWith" {
			pattern = "^" + pattern + ".*"
		} else if operator == "endsWith" {
			pattern = ".*" + pattern + "$"
		}
		return bson.M{c.identifier(node.Left): bson.M{"$regex": pattern}}

	case "..":
		panic("unsupported range")
	}

	panic(fmt.Sprintf("unknown operator (%v)", node.Operator))
}

// MatchesNode rejects the "matches" operator; it is not translated into a
// filter by this compiler.
func (c *compiler) MatchesNode(node *ast.MatchesNode) interface{} {
	const reason = "unsupported match node"
	panic(reason)
}

// PropertyNode renders a property access as a dotted field path
// ("<parent>.<property>"). The parent must itself translate to a string
// (an identifier, a nested property or an index placeholder); otherwise it
// panics with a dump of the parent node.
func (c *compiler) PropertyNode(node *ast.PropertyNode) string {
	parent, isString := c.compile(node.Node).(string)
	if !isString {
		panic(fmt.Sprintf("unsupported property for %v", ast.Dump(node.Node)))
	}
	return fmt.Sprintf("%s.%s", parent, node.Property)
}

// IndexNode renders an index access as the placeholder "{index-<key>}",
// where <key> is the translated index expression formatted with %v.
//
// NOTE(review): the indexed base (node.Node) is not part of the result, so
// a[0] and b[0] yield the same string. Confirm that whatever consumes this
// placeholder expects that.
func (c *compiler) IndexNode(node *ast.IndexNode) string {
	key := c.compile(node.Index)
	return fmt.Sprintf("{index-%v}", key)
}

// SliceNode rejects slice expressions; they are not translated into a filter.
func (c *compiler) SliceNode(node *ast.SliceNode) interface{} {
	const reason = "unsupported slice node"
	panic(reason)
}

// MethodNode rejects method calls on values; only the free functions handled
// by FunctionNode are translated.
func (c *compiler) MethodNode(node *ast.MethodNode) interface{} {
	const reason = "unsupported method node"
	panic(reason)
}

// FunctionNode translates the supported query functions into MongoDB filters:
//
//   - search(text) / q(text): {$text: {$search: text}}. The argument goes
//     through compile, not eval, so a bare identifier is used as the search
//     term itself (after renaming) rather than looked up in env.
//   - near(field, point, maxDistance): {field: {$near: {$geometry: Point,
//     $maxDistance: ...}}}. point must evaluate to a []interface{}.
//   - within(field, geotype, points): {field: {$geoWithin: {<shape>: points}}}.
//     geotype must be a key of geoTypes, and points must evaluate to a
//     []interface{}.
//   - In(field, values): {field: {$in: [...]}}. A list value (any slice or
//     array except byte slices, including typed slices such as []string) is
//     spread element-wise. Any other value is wrapped as a one-element list.
//   - icontains / istartsWith / iendsWith(field, text): case-insensitive,
//     regex-escaped $regex matches.
//
// For near and within, ".geometry" is appended to the field path unless it
// is already there. Missing arguments, wrong argument types and unknown
// function names panic.
func (c *compiler) FunctionNode(node *ast.FunctionNode) interface{} {
	// arg returns the i-th call argument. When the call has too few arguments
	// it panics with a descriptive message instead of a bare
	// "index out of range" runtime error.
	arg := func(i int) ast.Node {
		if i >= len(node.Arguments) {
			panic(fmt.Sprintf("%s: expected at least %d argument(s), got %d", node.Name, i+1, len(node.Arguments)))
		}
		return node.Arguments[i]
	}

	switch node.Name {
	case "search", "q":
		val := c.compile(arg(0))
		return bson.M{"$text": bson.M{"$search": val}}

	case "near":
		v := c.identifier(arg(0))
		point := c.eval(arg(1))
		distance := c.eval(arg(2))

		if v == "" {
			panic("incorrect argument, empty field name")
		}
		if !strings.HasSuffix(v, ".geometry") {
			v += ".geometry"
		}

		if _, ok := point.([]interface{}); !ok {
			panic("incorrect argument, point must coordinates array")
		}

		// bson.D keeps $geometry ahead of $maxDistance in the encoded document.
		return bson.M{
			v: bson.M{"$near": bson.D{{Key: "$geometry", Value: map[string]interface{}{"type": "Point", "coordinates": point}}, {Key: "$maxDistance", Value: distance}}},
		}

	case "within":
		v := c.identifier(arg(0))
		t := c.eval(arg(1))
		points := c.eval(arg(2))

		if v == "" {
			panic("incorrect argument, empty field name")
		}

		if !strings.HasSuffix(v, ".geometry") {
			v += ".geometry"
		}

		typ, ok := t.(string)
		if !ok {
			panic("incorrect argument, geotype must be string")
		}
		typ, ok = geoTypes[typ]
		if !ok {
			panic("incorrect geotype value")
		}

		if _, ok := points.([]interface{}); !ok {
			panic("incorrect argument, points must be array of coordinates")
		}

		return bson.M{
			v: bson.M{"$geoWithin": bson.M{typ: points}},
		}

	case "In":
		field := c.identifier(arg(0))
		if field == "" {
			panic("incorrect argument, empty field name")
		}

		// Evaluate the operand exactly once. Spread any slice or array
		// element-wise so typed slices from env (e.g. []string) produce a flat
		// $in list instead of a single nested array. Byte slices and scalars
		// are a single value.
		raw := c.eval(arg(1))
		var values []interface{}
		rv := reflect.ValueOf(raw)
		if k := rv.Kind(); (k == reflect.Slice || k == reflect.Array) && rv.Type().Elem().Kind() != reflect.Uint8 {
			values = make([]interface{}, rv.Len())
			for i := range values {
				values[i] = rv.Index(i).Interface()
			}
		} else {
			values = []interface{}{raw}
		}

		return bson.M{field: bson.M{"$in": values}}

	case "icontains":
		v := c.identifier(arg(0))
		t, ok := c.eval(arg(1)).(string)
		if !ok {
			panic("icontains requires string as an argument")
		}
		return bson.M{v: bson.M{"$regex": regexp.QuoteMeta(t), "$options": "i"}}

	case "istartsWith":
		v := c.identifier(arg(0))
		t, ok := c.eval(arg(1)).(string)
		if !ok {
			panic("istartsWith requires string as an argument")
		}
		return bson.M{v: bson.M{"$regex": fmt.Sprintf("^%s.*", regexp.QuoteMeta(t)), "$options": "i"}}

	case "iendsWith":
		v := c.identifier(arg(0))
		t, ok := c.eval(arg(1)).(string)
		if !ok {
			panic("iendsWith requires string as an argument")
		}
		return bson.M{v: bson.M{"$regex": fmt.Sprintf(".*%s$", regexp.QuoteMeta(t)), "$options": "i"}}
	}
	panic("unsupported function")
}

// BuiltinNode rejects expr builtins such as len, all, any, none, one, filter,
// map and count; none of them is translated into a filter. The panic message
// names the offending builtin to make the failure easier to diagnose.
func (c *compiler) BuiltinNode(node *ast.BuiltinNode) interface{} {
	panic(fmt.Sprintf("unsupported builtin node (%v)", node.Name))
}

//func (c *compiler) emitLoop(body func()) []byte {
//	i := c.makeConstant("i")
//	size := c.makeConstant("size")
//	array := c.makeConstant("array")
//
//	c.emit(OpLen)
//	c.emit(OpStore, size...)
//	c.emit(OpStore, array...)
//	c.emitPush(0)
//	c.emit(OpStore, i...)
//
//	cond := len(c.bytecode)
//	c.emit(OpLoad, i...)
//	c.emit(OpLoad, size...)
//	c.emit(OpLess)
//	end := c.emit(OpJumpIfFalse, c.placeholder()...)
//	c.emit(OpPop)
//
//	body()
//
//	c.emit(OpInc, i...)
//	c.emit(OpJumpBackward, c.calcBackwardJump(cond)...)
//
//	c.patchJump(end)
//	c.emit(OpPop)
//
//	return size
//}

// ClosureNode translates a closure by translating its body node directly.
func (c *compiler) ClosureNode(node *ast.ClosureNode) interface{} {
	body := node.Node
	return c.compile(body)
}

// PointerNode rejects pointer nodes (the implicit element reference inside
// closures); they are not translated into a filter.
func (c *compiler) PointerNode(node *ast.PointerNode) interface{} {
	const reason = "unsupported pointer node"
	panic(reason)
}

// ConditionalNode rejects ternary conditionals; they are not translated into
// a filter.
func (c *compiler) ConditionalNode(node *ast.ConditionalNode) interface{} {
	const reason = "unsupported conditional node"
	panic(reason)
}

// ArrayNode rejects array literals reached through compile. Array literals
// used as values (e.g. the right side of "in") go through eval instead.
func (c *compiler) ArrayNode(node *ast.ArrayNode) interface{} {
	const reason = "unsupported array node"
	panic(reason)
}

// MapNode rejects map literals reached through compile; they are not
// translated into a filter.
func (c *compiler) MapNode(node *ast.MapNode) interface{} {
	const reason = "unsupported map node"
	panic(reason)
}

// PairNode rejects key/value pairs of map literals; they are not translated
// into a filter.
func (c *compiler) PairNode(node *ast.PairNode) interface{} {
	const reason = "unsupported pair node"
	panic(reason)
}
