diff --git a/compiler/compiler.go b/compiler/compiler.go index ac11805e..5c7f3603 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -45,6 +45,9 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro case reflect.Float64: c.emit(OpCast, 2) } + if c.config.Optimize { + c.optimize() + } } program = NewProgram( @@ -1050,6 +1053,19 @@ func (c *compiler) derefInNeeded(node ast.Node) { } } +func (c *compiler) optimize() { + for i, op := range c.bytecode { + switch op { + case OpJumpIfTrue, OpJumpIfFalse, OpJumpIfNil, OpJumpIfNotNil: + target := i + c.arguments[i] + 1 + for target < len(c.bytecode) && c.bytecode[target] == op { + target += c.arguments[target] + 1 + } + c.arguments[i] = target - i - 1 + } + } +} + func kind(node ast.Node) reflect.Kind { t := node.Type() if t == nil { diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index ed11a9dd..b7bbfcfb 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -384,3 +384,175 @@ func TestCompile_OpCallFast(t *testing.T) { require.Equal(t, vm.OpCallFast, program.Bytecode[4]) require.Equal(t, 3, program.Arguments[4]) } + +func TestCompile_optimizes_jumps(t *testing.T) { + env := map[string]any{ + "a": true, + "b": true, + "c": true, + "d": true, + } + type op struct { + Bytecode vm.Opcode + Arg int + } + tests := []struct { + code string + want []op + }{ + { + `let foo = true; let bar = false; let baz = true; foo || bar || baz`, + []op{ + {vm.OpTrue, 0}, + {vm.OpStore, 0}, + {vm.OpFalse, 0}, + {vm.OpStore, 1}, + {vm.OpTrue, 0}, + {vm.OpStore, 2}, + {vm.OpLoadVar, 0}, + {vm.OpJumpIfTrue, 5}, + {vm.OpPop, 0}, + {vm.OpLoadVar, 1}, + {vm.OpJumpIfTrue, 2}, + {vm.OpPop, 0}, + {vm.OpLoadVar, 2}, + }, + }, + { + `a && b && c`, + []op{ + {vm.OpLoadFast, 0}, + {vm.OpJumpIfFalse, 5}, + {vm.OpPop, 0}, + {vm.OpLoadFast, 1}, + {vm.OpJumpIfFalse, 2}, + {vm.OpPop, 0}, + {vm.OpLoadFast, 2}, + }, + }, + { + `a && b || c && d`, + []op{ + {vm.OpLoadFast, 0}, + {vm.OpJumpIfFalse, 2}, + {vm.OpPop, 0}, + {vm.OpLoadFast, 1}, + {vm.OpJumpIfTrue, 5}, + {vm.OpPop, 0}, + {vm.OpLoadFast, 2}, + {vm.OpJumpIfFalse, 2}, + {vm.OpPop, 0}, + {vm.OpLoadFast, 3}, + }, + }, + { + `filter([1, 2, 3, 4, 5], # > 3 && # != 4 && # != 5)`, + []op{ + {vm.OpPush, 0}, + {vm.OpBegin, 0}, + {vm.OpJumpIfEnd, 26}, + {vm.OpPointer, 0}, + {vm.OpDeref, 0}, + {vm.OpPush, 1}, + {vm.OpMore, 0}, + {vm.OpJumpIfFalse, 18}, + {vm.OpPop, 0}, + {vm.OpPointer, 0}, + {vm.OpDeref, 0}, + {vm.OpPush, 2}, + {vm.OpEqual, 0}, + {vm.OpNot, 0}, + {vm.OpJumpIfFalse, 11}, + {vm.OpPop, 0}, + {vm.OpPointer, 0}, + {vm.OpDeref, 0}, + {vm.OpPush, 3}, + {vm.OpEqual, 0}, + {vm.OpNot, 0}, + {vm.OpJumpIfFalse, 4}, + {vm.OpPop, 0}, + {vm.OpIncrementCount, 0}, + {vm.OpPointer, 0}, + {vm.OpJump, 1}, + {vm.OpPop, 0}, + {vm.OpIncrementIndex, 0}, + {vm.OpJumpBackward, 27}, + {vm.OpGetCount, 0}, + {vm.OpEnd, 0}, + {vm.OpArray, 0}, + }, + }, + { + `let foo = true; let bar = false; let baz = true; foo && bar || baz`, + []op{ + {vm.OpTrue, 0}, + {vm.OpStore, 0}, + {vm.OpFalse, 0}, + {vm.OpStore, 1}, + {vm.OpTrue, 0}, + {vm.OpStore, 2}, + {vm.OpLoadVar, 0}, + {vm.OpJumpIfFalse, 2}, + {vm.OpPop, 0}, + {vm.OpLoadVar, 1}, + {vm.OpJumpIfTrue, 2}, + {vm.OpPop, 0}, + {vm.OpLoadVar, 2}, + }, + }, + { + `true ?? nil ?? nil ?? nil`, + []op{ + {vm.OpTrue, 0}, + {vm.OpJumpIfNotNil, 8}, + {vm.OpPop, 0}, + {vm.OpNil, 0}, + {vm.OpJumpIfNotNil, 5}, + {vm.OpPop, 0}, + {vm.OpNil, 0}, + {vm.OpJumpIfNotNil, 2}, + {vm.OpPop, 0}, + {vm.OpNil, 0}, + }, + }, + { + `let m = {"a": {"b": {"c": 1}}}; m?.a?.b?.c`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpPush, 2}, + {vm.OpPush, 3}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpStore, 0}, + {vm.OpLoadVar, 0}, + {vm.OpJumpIfNil, 8}, + {vm.OpPush, 0}, + {vm.OpFetch, 0}, + {vm.OpJumpIfNil, 5}, + {vm.OpPush, 1}, + {vm.OpFetch, 0}, + {vm.OpJumpIfNil, 2}, + {vm.OpPush, 2}, + {vm.OpFetch, 0}, + }, + }, + } + + for _, test := range tests { + t.Run(test.code, func(t *testing.T) { + program, err := expr.Compile(test.code, expr.Env(env)) + require.NoError(t, err) + + require.Equal(t, len(test.want), len(program.Bytecode)) + for i, op := range test.want { + require.Equal(t, op.Bytecode, program.Bytecode[i]) + require.Equalf(t, op.Arg, program.Arguments[i], "at %d", i) + } + }) + } +}