diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index 624a781c5b9..0476cb17693 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -343,6 +343,80 @@ TEST_F(TransformTest, ImbalancedTreeArithmeticDeep) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); } +TEST_F(TransformTest, DeeplyNestedArithmeticLogicalExpression) +{ + // Test logic for deeply nested arithmetic and logical expressions. + constexpr int64_t left_depth_level = 100; + constexpr int64_t right_depth_level = 75; + + auto generate_ast_expr = [](int64_t depth_level, + cudf::ast::column_reference col_ref, + cudf::ast::ast_operator root_operator, + cudf::ast::ast_operator arithmetic_operator, + bool nested_left_tree) { + // Note that a std::list is required here because of its guarantees against reference + // invalidation when items are added or removed. References to items in a std::vector are not + // safe if the vector must re-allocate. + auto expressions = std::list(); + + auto op = arithmetic_operator; + expressions.push_back(cudf::ast::operation(op, col_ref, col_ref)); + + for (int64_t i = 0; i < depth_level - 1; i++) { + if (i == depth_level - 2) { + op = root_operator; + } else { + op = arithmetic_operator; + } + if (nested_left_tree) { + expressions.push_back(cudf::ast::operation(op, expressions.back(), col_ref)); + } else { + expressions.push_back(cudf::ast::operation(op, col_ref, expressions.back())); + } + } + return expressions; + }; + + auto c_0 = column_wrapper{0, 0, 0}; + auto c_1 = column_wrapper{0, 0, 0}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + + auto left_expression = generate_ast_expr(left_depth_level, + col_ref_0, + cudf::ast::ast_operator::LESS, + cudf::ast::ast_operator::ADD, + false); + auto right_expression = generate_ast_expr(right_depth_level, + col_ref_1, + cudf::ast::ast_operator::EQUAL, + cudf::ast::ast_operator::SUB, + true); + + auto expression_tree = cudf::ast::operation( + cudf::ast::ast_operator::LOGICAL_OR, left_expression.back(), right_expression.back()); + + // Expression: + // OR(<(+(+(+(+($0, $0), $0), $0), $0), $0), ==($1, -($1, -($1, -($1, -($1, $1)))))) + // ... + // OR(<($L, $0), ==($1, $R)) + // true + // + // Breakdown: + // - Left Operand ($L): (+(+(+(+($0, $0), $0), $0), $0), $0) + // - Right Operand ($R): -($1, -($1, -($1, -($1, $1)))) + // Explanation: + // If all $1 values and $R values are zeros, the result is true because of the equality check + // combined with the OR operator in OR(<($L, $0), ==($1, $R)). + + auto result = cudf::compute_column(table, expression_tree); + auto expected = column_wrapper{true, true, true}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); +} + TEST_F(TransformTest, MultiLevelTreeComparator) { auto c_0 = column_wrapper{3, 20, 1, 50};