Skip to content
Snippets Groups Projects
Commit e8011e6c authored by Semyon Krestyaninov's avatar Semyon Krestyaninov :dog2:
Browse files

wip

parent 2fcf77f2
No related branches found
No related tags found
No related merge requests found
package expr
import (
"fmt"
"slices"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/file"
"github.com/expr-lang/expr/parser/operator"
)
type PruneidentWalker struct {
idents map[string]struct{}
pruned map[file.Location]struct{}
scopes []scope
}
func NewPruneidentWalker(idents []string) *PruneidentWalker {
w := &PruneidentWalker{
idents: make(map[string]struct{}),
}
for _, ident := range idents {
w.idents[ident] = struct{}{}
}
return w
}
type scope struct {
variable string
}
func (w *PruneidentWalker) Walk(node *ast.Node) {
w.scopes = []scope{}
w.pruned = make(map[file.Location]struct{})
w.walk(node)
}
func (w *PruneidentWalker) walk(node *ast.Node) {
if node == nil || *node == nil {
return
}
switch n := (*node).(type) {
case *ast.NilNode:
case *ast.IdentifierNode:
case *ast.IntegerNode:
case *ast.FloatNode:
case *ast.BoolNode:
case *ast.StringNode:
case *ast.ConstantNode:
case *ast.UnaryNode:
w.walk(&n.Node)
case *ast.BinaryNode:
w.walk(&n.Left)
w.walk(&n.Right)
case *ast.ChainNode:
w.walk(&n.Node)
case *ast.MemberNode:
w.walk(&n.Node)
w.walk(&n.Property)
case *ast.SliceNode:
w.walk(&n.Node)
if n.From != nil {
w.walk(&n.From)
}
if n.To != nil {
w.walk(&n.To)
}
case *ast.CallNode:
w.walk(&n.Callee)
for i := range n.Arguments {
w.walk(&n.Arguments[i])
}
case *ast.BuiltinNode:
for i := range n.Arguments {
w.walk(&n.Arguments[i])
}
case *ast.PredicateNode:
w.walk(&n.Node)
case *ast.PointerNode:
case *ast.VariableDeclaratorNode:
w.walk(&n.Value)
w.beginScope(n.Name)
w.walk(&n.Expr)
w.endScope()
case *ast.SequenceNode:
for i := range n.Nodes {
w.walk(&n.Nodes[i])
}
case *ast.ConditionalNode:
w.walk(&n.Cond)
w.walk(&n.Exp1)
w.walk(&n.Exp2)
case *ast.ArrayNode:
for i := range n.Nodes {
w.walk(&n.Nodes[i])
}
case *ast.MapNode:
for i := range n.Pairs {
w.walk(&n.Pairs[i])
}
case *ast.PairNode:
w.walk(&n.Key)
w.walk(&n.Value)
default:
panic(fmt.Sprintf("undefined node type (%T)", node))
}
w.visit(node)
}
func (w *PruneidentWalker) visit(node *ast.Node) {
switch n := (*node).(type) {
case *ast.IdentifierNode:
if w.mustPruned(n.Value) {
w.pruneNode(node)
}
case *ast.UnaryNode:
if w.prunedNode(n.Node) {
w.pruneNode(node)
}
case *ast.BinaryNode:
if n.Operator == "in" {
if in, ok := n.Right.(*ast.IdentifierNode); ok && in.Value == "$env" {
if sn, ok := n.Left.(*ast.StringNode); ok && w.mustPruned(sn.Value) {
w.pruneNode(node)
return
}
}
}
leftPruned := w.prunedNode(n.Left)
rightPruned := w.prunedNode(n.Right)
if operator.IsBoolean(n.Operator) {
switch {
case leftPruned && rightPruned:
w.pruneNode(node)
case leftPruned:
ast.Patch(node, n.Right)
case rightPruned:
ast.Patch(node, n.Left)
}
} else {
if leftPruned || rightPruned {
w.pruneNode(node)
}
}
case *ast.ChainNode:
if w.prunedNode(n.Node) {
w.pruneNode(node)
}
case *ast.MemberNode:
if in, ok := n.Node.(*ast.IdentifierNode); ok && in.Value == "$env" {
if sn, ok := n.Property.(*ast.StringNode); ok && w.mustPruned(sn.Value) {
w.pruneNode(node)
return
}
}
if w.prunedNode(n.Node) || w.prunedNode(n.Property) {
w.pruneNode(node)
}
case *ast.SliceNode:
if w.prunedNode(n.Node) || w.prunedNode(n.From) || w.prunedNode(n.To) {
w.pruneNode(node)
}
case *ast.CallNode:
if w.prunedNode(n.Callee) {
w.pruneNode(node)
}
for _, arg := range n.Arguments {
if w.prunedNode(arg) {
w.pruneNode(node)
}
}
case *ast.BuiltinNode:
for _, arg := range n.Arguments {
if w.prunedNode(arg) {
w.pruneNode(node)
}
}
case *ast.PredicateNode:
if w.prunedNode(n.Node) {
w.pruneNode(node)
}
case *ast.ConditionalNode:
if w.prunedNode(n.Cond) || w.prunedNode(n.Exp1) || w.prunedNode(n.Exp2) {
w.pruneNode(node)
}
case *ast.VariableDeclaratorNode:
if w.prunedNode(n.Value) || w.prunedNode(n.Expr) {
w.pruneNode(node)
}
case *ast.SequenceNode:
n.Nodes = slices.DeleteFunc(n.Nodes, w.prunedNode)
case *ast.ArrayNode:
n.Nodes = slices.DeleteFunc(n.Nodes, w.prunedNode)
case *ast.MapNode:
n.Pairs = slices.DeleteFunc(n.Pairs, w.prunedNode)
case *ast.PairNode:
if w.prunedNode(n.Key) || w.prunedNode(n.Value) {
w.pruneNode(node)
}
}
}
func (w *PruneidentWalker) mustPruned(ident string) bool {
_, exists := w.idents[ident]
return exists && !w.scoped(ident)
}
func (w *PruneidentWalker) pruneNode(node *ast.Node) {
if node == nil || *node == nil {
return
}
prune := &ast.NilNode{}
ast.Patch(node, prune)
(*node).SetType(prune.Type())
w.pruned[prune.Location()] = struct{}{}
}
func (w *PruneidentWalker) prunedNode(node ast.Node) bool {
if node == nil {
return false
}
if n, ok := node.(*ast.NilNode); ok {
_, exists := w.pruned[n.Location()]
return exists
}
return false
}
func (w *PruneidentWalker) beginScope(variable string) {
w.scopes = append(w.scopes, scope{variable: variable})
}
func (w *PruneidentWalker) endScope() {
w.scopes = w.scopes[:len(w.scopes)-1]
}
func (w *PruneidentWalker) scoped(variable string) bool {
for i := len(w.scopes) - 1; i >= 0; i-- {
if w.scopes[i].variable == variable {
return true
}
}
return false
}
package expr
import (
"fmt"
"testing"
"github.com/expr-lang/expr/parser"
"github.com/expr-lang/expr/parser/operator"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPruneidentWalker_Operator(t *testing.T) {
t.Run("Binary", func(t *testing.T) {
for op := range operator.Binary {
input := fmt.Sprintf("foo %s bar", op)
if op == "|" {
input = fmt.Sprintf("foo | pipe()")
}
t.Run(input, func(t *testing.T) {
tree, err := parser.Parse(input)
require.NoError(t, err)
walker := NewPruneidentWalker([]string{"foo"})
walker.Walk(&tree.Node)
if operator.IsBoolean(op) {
assert.Equal(t, "bar", tree.Node.String())
} else {
assert.Equal(t, "nil", tree.Node.String())
}
})
}
})
t.Run("Unary", func(t *testing.T) {
for op := range operator.Unary {
input := fmt.Sprintf("%s foo", op)
t.Run(input, func(t *testing.T) {
tree, err := parser.Parse(input)
require.NoError(t, err)
walker := NewPruneidentWalker([]string{"foo"})
walker.Walk(&tree.Node)
assert.Equal(t, "nil", tree.Node.String())
})
}
})
}
func TestPruneidentWalker_Walk(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "member",
input: "foo.bar",
want: "nil",
},
{
name: "member",
input: "bar.foo",
want: "bar.foo",
},
{
name: "chain",
input: "foo?.bar",
want: "nil",
},
{
name: "chain",
input: "bar?.foo",
want: "bar?.foo",
},
{
name: "slice",
input: "foo[:]",
want: "nil",
},
{
name: "slice",
input: "bar[foo:]",
want: "nil",
},
{
name: "slice",
input: "bar[:foo]",
want: "nil",
},
{
name: "call",
input: "foo.bar()",
want: "nil",
},
{
name: "call",
input: "foo(bar)",
want: "nil",
},
{
name: "call",
input: "bar(foo)",
want: "nil",
},
{
name: "builtin",
input: "duration(foo)",
want: "nil",
},
{
name: "builtin",
input: "max(foo, bar)",
want: "nil",
},
{
name: "predicate",
input: "filter(foo, .bar == 1)",
want: "nil",
},
{
name: "predicate",
input: "filter(bar, .foo == 1)",
want: "filter(bar, .foo == 1)",
},
{
name: "predicate",
input: "filter(qux, .bar == foo)",
want: "nil",
},
{
name: "predicate",
input: `filter(qux, let foo = 1; .bar == foo)`,
want: "filter(qux, let foo = 1; .bar == foo)",
},
{
name: "conditional",
input: "foo ? bar : baz",
want: "nil",
},
{
name: "conditional",
input: "bar ? foo : baz",
want: "nil",
},
{
name: "conditional",
input: "bar ? baz : foo",
want: "nil",
},
{
name: "conditional",
input: "foo == bar ? baz : qux",
want: "nil",
},
{
name: "variable declarator",
input: "let bar = 1; let baz = foo; bar + baz",
want: "nil",
},
{
name: "variable declarator",
input: "let foo = bar; let baz = qux; foo + baz",
want: "let foo = bar; let baz = qux; foo + baz",
},
{
name: "variable declarator",
input: "let bar = foo; bar",
want: "nil",
},
{
name: "sequence",
input: "bar; foo; baz",
want: "bar; baz",
},
{
name: "array",
input: "[foo, bar]",
want: "[bar]",
},
{
name: "array",
input: "[foo]",
want: "[]",
},
{
name: "map",
input: "{foo: bar}",
want: "{foo: bar}",
},
{
name: "map",
input: "{bar: foo}",
want: "{}",
},
{
name: "combined",
input: "foo > 2 || bar == 3",
want: "bar == 3",
},
{
name: "combined",
input: "(foo > 2 || bar == 3) && foo % 2 == 0 || bar - 1 == 2",
want: "bar == 3 || bar - 1 == 2",
},
{
name: "combined",
input: "all(bar, .baz == qux || .baz == foo)",
want: "all(bar, .baz == qux)",
},
{
name: "$env",
input: "$env.foo",
want: "nil",
},
{
name: "$env",
input: `$env["foo"]`,
want: "nil",
},
{
name: "$env",
input: `"foo" in $env`,
want: "nil",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tree, err := parser.Parse(tc.input)
require.NoError(t, err)
walker := NewPruneidentWalker([]string{"foo"})
walker.Walk(&tree.Node)
assert.Equal(t, tc.want, tree.Node.String())
})
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment