diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 869df578185ed..81c50adf19a57 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -4,6 +4,9 @@ #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" +#include +#include + TLANG_NAMESPACE_BEGIN class IndependentBlocksJudger : public BasicStmtVisitor { public: @@ -78,6 +81,51 @@ class IndependentBlocksJudger : public BasicStmtVisitor { bool is_inside_loop_ = false; }; +// Remove the duplicated IBs, remove blocks who are others' children because +// each block should only be processed once +class DuplicateIndependentBlocksCleaner : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + + void check_children_ib(Block *target_block) { + // Remove the block if it is the child of the block being visiting + if (independent_blocks_cleaned_.find(target_block) != + independent_blocks_cleaned_.end()) { + independent_blocks_cleaned_.erase(target_block); + } + } + + void visit(StructForStmt *stmt) override { + check_children_ib(stmt->body.get()); + stmt->body->accept(this); + } + void visit(RangeForStmt *stmt) override { + check_children_ib(stmt->body.get()); + stmt->body->accept(this); + } + + static std::set run( + const std::vector> &raw_IBs) { + DuplicateIndependentBlocksCleaner cleaner; + // Remove duplicate IBs + for (auto const &item : raw_IBs) { + cleaner.independent_blocks_cleaned_.insert(item.second); + } + // No clean is needed if only one IB exists + if (cleaner.independent_blocks_cleaned_.size() > 1) { + // Check from the block with smallest depth, ensure no duplicate visit + // happens + for (const auto &block : cleaner.independent_blocks_cleaned_) { + block->accept(&cleaner); + } + } + return cleaner.independent_blocks_cleaned_; + } + + private: + std::set independent_blocks_cleaned_; +}; + // Do automatic differentiation pass in the reverse order (reverse-mode AD) // Independent Block (IB): blocks (i.e. loop bodies) whose iterations are @@ -120,9 +168,21 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { void visit_loop_body(Block *block) { if (is_independent_block(block)) { current_ib_ = block; + auto old_current_ib_ = current_ib_; block->accept(this); + // Lower level block is not an IB, therefore store the current block as an + // IB + if (old_current_ib_ == current_ib_) { + independent_blocks_.push_back({depth_, current_ib_}); + } } else { - // No need to dive further + if (depth_ <= 1) { + TI_ASSERT(depth_ == 1); + // The top level block is already not an IB, store it + independent_blocks_.push_back({depth_ - 1, block}); + } else { + independent_blocks_.push_back({depth_ - 1, block->parent_block()}); + } } } @@ -132,9 +192,6 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { current_ib_ = stmt->body.get(); visit_loop_body(stmt->body.get()); depth_--; - if (depth_ == 0) { - independent_blocks_.push_back(current_ib_); - } } void visit(RangeForStmt *stmt) override { @@ -144,12 +201,9 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { depth_++; visit_loop_body(stmt->body.get()); depth_--; - if (depth_ == 0) { - independent_blocks_.push_back(current_ib_); - } } - static std::vector run(IRNode *root) { + static std::set run(IRNode *root) { IdentifyIndependentBlocks pass; Block *block = root->as(); bool has_for = false; @@ -160,16 +214,23 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { } if (!has_for) { // The whole block is an IB - pass.independent_blocks_.push_back(block); + pass.independent_blocks_.push_back({0, block}); } else { root->accept(&pass); } + // Sort the IBs by their depth from shallow to deep + std::sort(pass.independent_blocks_.begin(), pass.independent_blocks_.end(), + [](const std::pair &a, + const std::pair &b) -> bool { + return a.first < b.first; + }); + TI_ASSERT(!pass.independent_blocks_.empty()); - return pass.independent_blocks_; + return DuplicateIndependentBlocksCleaner::run(pass.independent_blocks_); } private: - std::vector independent_blocks_; + std::vector> independent_blocks_; int depth_{0}; Block *current_ib_{nullptr}; }; @@ -393,7 +454,7 @@ class ReverseOuterLoops : public BasicStmtVisitor { using BasicStmtVisitor::visit; private: - ReverseOuterLoops(const std::vector &IB) : loop_depth_(0), ib_(IB) { + ReverseOuterLoops(const std::set &IB) : loop_depth_(0), ib_(IB) { } bool is_ib(Block *block) const { @@ -418,10 +479,10 @@ class ReverseOuterLoops : public BasicStmtVisitor { } int loop_depth_; - std::vector ib_; + std::set ib_; public: - static void run(IRNode *root, const std::vector &IB) { + static void run(IRNode *root, const std::set &IB) { ReverseOuterLoops pass(IB); root->accept(&pass); } diff --git a/tests/python/test_ad_for.py b/tests/python/test_ad_for.py index b0b3cb6d3bb10..8d1cc36b0debf 100644 --- a/tests/python/test_ad_for.py +++ b/tests/python/test_ad_for.py @@ -775,3 +775,188 @@ def test_large_loop(): assert loss[None] == 1e7 assert x.grad[None] == 1e7 + + +@test_utils.test(require=ti.extension.adstack) +def test_multiple_ib(): + x = ti.field(float, (), needs_grad=True) + y = ti.field(float, (), needs_grad=True) + + @ti.kernel + def compute_y(): + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + y[None] += x[None] + + x[None] = 1.0 + with ti.Tape(y): + compute_y() + + assert y[None] == 12.0 + assert x.grad[None] == 12.0 + + +@test_utils.test(require=ti.extension.adstack) +def test_multiple_ib_multiple_outermost(): + x = ti.field(float, (), needs_grad=True) + y = ti.field(float, (), needs_grad=True) + + @ti.kernel + def compute_y(): + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + y[None] += x[None] + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + y[None] += x[None] + + x[None] = 1.0 + with ti.Tape(y): + compute_y() + + assert y[None] == 24.0 + assert x.grad[None] == 24.0 + + +@test_utils.test(require=ti.extension.adstack) +def test_multiple_ib_multiple_outermost_mixed(): + x = ti.field(float, (), needs_grad=True) + y = ti.field(float, (), needs_grad=True) + + @ti.kernel + def compute_y(): + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + y[None] += x[None] + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + y[None] += x[None] + for ii in range(3): + y[None] += x[None] + + x[None] = 1.0 + with ti.Tape(y): + compute_y() + + assert y[None] == 42.0 + assert x.grad[None] == 42.0 + + +@test_utils.test(require=ti.extension.adstack) +def test_multiple_ib_mixed(): + x = ti.field(float, (), needs_grad=True) + y = ti.field(float, (), needs_grad=True) + + @ti.kernel + def compute_y(): + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + y[None] += x[None] + for k in range(2): + y[None] += x[None] + for i in range(3): + y[None] += x[None] + + x[None] = 1.0 + with ti.Tape(y): + compute_y() + + assert y[None] == 30.0 + assert x.grad[None] == 30.0 + + +@test_utils.test(require=ti.extension.adstack) +def test_multiple_ib_deeper(): + x = ti.field(float, (), needs_grad=True) + y = ti.field(float, (), needs_grad=True) + + @ti.kernel + def compute_y(): + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + for ii in range(2): + y[None] += x[None] + for i in range(3): + for ii in range(2): + for iii in range(2): + y[None] += x[None] + + x[None] = 1.0 + with ti.Tape(y): + compute_y() + + assert y[None] == 42.0 + assert x.grad[None] == 42.0 + + +@test_utils.test(require=ti.extension.adstack) +def test_multiple_ib_deeper_non_scalar(): + N = 10 + x = ti.field(float, shape=N, needs_grad=True) + y = ti.field(float, shape=N, needs_grad=True) + + @ti.kernel + def compute_y(): + for j in range(N): + for i in range(j): + y[j] += x[j] + for i in range(3): + for ii in range(j): + y[j] += x[j] + for i in range(3): + for ii in range(2): + for iii in range(j): + y[j] += x[j] + + x.fill(1.0) + for i in range(N): + y.grad[i] = 1.0 + compute_y() + compute_y.grad() + for i in range(N): + assert y[i] == i * 10.0 + assert x.grad[i] == i * 10.0 + + +@test_utils.test(require=ti.extension.adstack) +def test_multiple_ib_inner_mixed(): + x = ti.field(float, (), needs_grad=True) + y = ti.field(float, (), needs_grad=True) + + @ti.kernel + def compute_y(): + for j in range(2): + for i in range(3): + y[None] += x[None] + for i in range(3): + for ii in range(2): + y[None] += x[None] + for iii in range(2): + y[None] += x[None] + for iiii in range(2): + y[None] += x[None] + for i in range(3): + for ii in range(2): + for iii in range(2): + y[None] += x[None] + + x[None] = 1.0 + with ti.Tape(y): + compute_y() + + assert y[None] == 78.0 + assert x.grad[None] == 78.0