Skip to content

Commit

Permalink
[autodiff] Handle multiple, mixed Independent Blocks (IBs) within mul…
Browse files Browse the repository at this point in the history
…ti-levels serial for-loops (#4523)

* collect all IBs, only preserve the outer most IB

* handle multiple IBs and mixed IB/Non-IBs

* fix typo and add test
  • Loading branch information
erizmr authored Mar 16, 2022
1 parent 7d1596a commit be2357e
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 14 deletions.
89 changes: 75 additions & 14 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"

#include <typeinfo>
#include <algorithm>

TLANG_NAMESPACE_BEGIN
class IndependentBlocksJudger : public BasicStmtVisitor {
public:
Expand Down Expand Up @@ -78,6 +81,51 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
bool is_inside_loop_ = false;
};

// Remove the duplicated IBs, remove blocks who are others' children because
// each block should only be processed once
class DuplicateIndependentBlocksCleaner : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;

void check_children_ib(Block *target_block) {
// Remove the block if it is the child of the block being visiting
if (independent_blocks_cleaned_.find(target_block) !=
independent_blocks_cleaned_.end()) {
independent_blocks_cleaned_.erase(target_block);
}
}

void visit(StructForStmt *stmt) override {
check_children_ib(stmt->body.get());
stmt->body->accept(this);
}
void visit(RangeForStmt *stmt) override {
check_children_ib(stmt->body.get());
stmt->body->accept(this);
}

static std::set<Block *> run(
const std::vector<std::pair<int, Block *>> &raw_IBs) {
DuplicateIndependentBlocksCleaner cleaner;
// Remove duplicate IBs
for (auto const &item : raw_IBs) {
cleaner.independent_blocks_cleaned_.insert(item.second);
}
// No clean is needed if only one IB exists
if (cleaner.independent_blocks_cleaned_.size() > 1) {
// Check from the block with smallest depth, ensure no duplicate visit
// happens
for (const auto &block : cleaner.independent_blocks_cleaned_) {
block->accept(&cleaner);
}
}
return cleaner.independent_blocks_cleaned_;
}

private:
std::set<Block *> independent_blocks_cleaned_;
};

// Do automatic differentiation pass in the reverse order (reverse-mode AD)

// Independent Block (IB): blocks (i.e. loop bodies) whose iterations are
Expand Down Expand Up @@ -120,9 +168,21 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor {
// Records the block to be treated as an IB for the loop body `block`, paired
// with its depth, in `independent_blocks_`.
void visit_loop_body(Block *block) {
  if (!is_independent_block(block)) {
    if (depth_ <= 1) {
      TI_ASSERT(depth_ == 1);
      // The top level loop body is itself not an IB: store it as-is.
      independent_blocks_.push_back({depth_ - 1, block});
      return;
    }
    // A deeper loop body that is not an IB: store its parent block instead.
    independent_blocks_.push_back({depth_ - 1, block->parent_block()});
    return;
  }

  current_ib_ = block;
  auto ib_before_descent = current_ib_;
  block->accept(this);
  // If visiting the body left `current_ib_` untouched, no lower level block
  // replaced it, so the current block is stored as an IB.
  if (ib_before_descent == current_ib_) {
    independent_blocks_.push_back({depth_, current_ib_});
  }
}

Expand All @@ -132,9 +192,6 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor {
current_ib_ = stmt->body.get();
visit_loop_body(stmt->body.get());
depth_--;
if (depth_ == 0) {
independent_blocks_.push_back(current_ib_);
}
}

void visit(RangeForStmt *stmt) override {
Expand All @@ -144,12 +201,9 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor {
depth_++;
visit_loop_body(stmt->body.get());
depth_--;
if (depth_ == 0) {
independent_blocks_.push_back(current_ib_);
}
}

static std::vector<Block *> run(IRNode *root) {
static std::set<Block *> run(IRNode *root) {
IdentifyIndependentBlocks pass;
Block *block = root->as<Block>();
bool has_for = false;
Expand All @@ -160,16 +214,23 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor {
}
if (!has_for) {
// The whole block is an IB
pass.independent_blocks_.push_back(block);
pass.independent_blocks_.push_back({0, block});
} else {
root->accept(&pass);
}
// Sort the IBs by their depth from shallow to deep
std::sort(pass.independent_blocks_.begin(), pass.independent_blocks_.end(),
[](const std::pair<int, Block *> &a,
const std::pair<int, Block *> &b) -> bool {
return a.first < b.first;
});

TI_ASSERT(!pass.independent_blocks_.empty());
return pass.independent_blocks_;
return DuplicateIndependentBlocksCleaner::run(pass.independent_blocks_);
}

private:
std::vector<Block *> independent_blocks_;
std::vector<std::pair<int, Block *>> independent_blocks_;
int depth_{0};
Block *current_ib_{nullptr};
};
Expand Down Expand Up @@ -393,7 +454,7 @@ class ReverseOuterLoops : public BasicStmtVisitor {
using BasicStmtVisitor::visit;

private:
ReverseOuterLoops(const std::vector<Block *> &IB) : loop_depth_(0), ib_(IB) {
ReverseOuterLoops(const std::set<Block *> &IB) : loop_depth_(0), ib_(IB) {
}

bool is_ib(Block *block) const {
Expand All @@ -418,10 +479,10 @@ class ReverseOuterLoops : public BasicStmtVisitor {
}

int loop_depth_;
std::vector<Block *> ib_;
std::set<Block *> ib_;

public:
static void run(IRNode *root, const std::vector<Block *> &IB) {
static void run(IRNode *root, const std::set<Block *> &IB) {
ReverseOuterLoops pass(IB);
root->accept(&pass);
}
Expand Down
185 changes: 185 additions & 0 deletions tests/python/test_ad_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,3 +775,188 @@ def test_large_loop():

assert loss[None] == 1e7
assert x.grad[None] == 1e7


@test_utils.test(require=ti.extension.adstack)
def test_multiple_ib():
    """Two sibling inner loops inside one outer loop.

    Every outer iteration accumulates x into y 3 + 3 times, so both y and
    x.grad are expected to be 2 * 6 = 12.
    """
    x = ti.field(float, (), needs_grad=True)
    y = ti.field(float, (), needs_grad=True)

    @ti.kernel
    def compute_y():
        for outer in range(2):
            # First inner loop.
            for a in range(3):
                y[None] += x[None]
            # Second inner loop, sibling of the first.
            for b in range(3):
                y[None] += x[None]

    expected = 12.0
    x[None] = 1.0
    with ti.Tape(y):
        compute_y()

    assert y[None] == expected
    assert x.grad[None] == expected


@test_utils.test(require=ti.extension.adstack)
def test_multiple_ib_multiple_outermost():
    """Two outermost loops, each containing two sibling inner loops.

    Each outermost loop contributes 2 * (3 + 3) = 12, so both y and x.grad
    are expected to be 24.
    """
    x = ti.field(float, (), needs_grad=True)
    y = ti.field(float, (), needs_grad=True)

    @ti.kernel
    def compute_y():
        # First outermost loop.
        for outer_a in range(2):
            for a in range(3):
                y[None] += x[None]
            for b in range(3):
                y[None] += x[None]
        # Second outermost loop with the same shape.
        for outer_b in range(2):
            for a in range(3):
                y[None] += x[None]
            for b in range(3):
                y[None] += x[None]

    expected = 24.0
    x[None] = 1.0
    with ti.Tape(y):
        compute_y()

    assert y[None] == expected
    assert x.grad[None] == expected


@test_utils.test(require=ti.extension.adstack)
def test_multiple_ib_multiple_outermost_mixed():
    """Two outermost loops; the last inner loop mixes a statement and a loop.

    First outermost loop: 2 * (3 + 3) = 12.
    Second outermost loop: 2 * (3 + 3 * (1 + 3)) = 30.
    Both y and x.grad are expected to be 42.
    """
    x = ti.field(float, (), needs_grad=True)
    y = ti.field(float, (), needs_grad=True)

    @ti.kernel
    def compute_y():
        for outer_a in range(2):
            for a in range(3):
                y[None] += x[None]
            for b in range(3):
                y[None] += x[None]
        for outer_b in range(2):
            for a in range(3):
                y[None] += x[None]
            for b in range(3):
                # An accumulation followed by a nested loop in the same body.
                y[None] += x[None]
                for c in range(3):
                    y[None] += x[None]

    expected = 42.0
    x[None] = 1.0
    with ti.Tape(y):
        compute_y()

    assert y[None] == expected
    assert x.grad[None] == expected


@test_utils.test(require=ti.extension.adstack)
def test_multiple_ib_mixed():
    """Three sibling inner loops; the middle one mixes a statement and a loop.

    Every outer iteration contributes 3 + 3 * (1 + 2) + 3 = 15, so both y and
    x.grad are expected to be 2 * 15 = 30.
    """
    x = ti.field(float, (), needs_grad=True)
    y = ti.field(float, (), needs_grad=True)

    @ti.kernel
    def compute_y():
        for outer in range(2):
            for a in range(3):
                y[None] += x[None]
            for b in range(3):
                # An accumulation followed by a nested loop in the same body.
                y[None] += x[None]
                for inner in range(2):
                    y[None] += x[None]
            for c in range(3):
                y[None] += x[None]

    expected = 30.0
    x[None] = 1.0
    with ti.Tape(y):
        compute_y()

    assert y[None] == expected
    assert x.grad[None] == expected


@test_utils.test(require=ti.extension.adstack)
def test_multiple_ib_deeper():
    """Sibling loop nests of increasing depth inside one outer loop.

    Every outer iteration contributes 3 + 3 * 2 + 3 * 2 * 2 = 21, so both y
    and x.grad are expected to be 2 * 21 = 42.
    """
    x = ti.field(float, (), needs_grad=True)
    y = ti.field(float, (), needs_grad=True)

    @ti.kernel
    def compute_y():
        for outer in range(2):
            # One level of nesting.
            for a in range(3):
                y[None] += x[None]
            # Two levels of nesting.
            for b in range(3):
                for b_inner in range(2):
                    y[None] += x[None]
            # Three levels of nesting.
            for c in range(3):
                for c_mid in range(2):
                    for c_inner in range(2):
                        y[None] += x[None]

    expected = 42.0
    x[None] = 1.0
    with ti.Tape(y):
        compute_y()

    assert y[None] == expected
    assert x.grad[None] == expected


@test_utils.test(require=ti.extension.adstack)
def test_multiple_ib_deeper_non_scalar():
    """Sibling loop nests of increasing depth over 1-D fields.

    The innermost range of every nest is the outer index j, so element j
    accumulates j + 3 * j + 3 * 2 * j = 10 * j. The gradient is seeded by
    hand through y.grad and the kernel's grad() is called explicitly.
    """
    N = 10
    x = ti.field(float, shape=N, needs_grad=True)
    y = ti.field(float, shape=N, needs_grad=True)

    @ti.kernel
    def compute_y():
        for j in range(N):
            # One level of nesting.
            for a in range(j):
                y[j] += x[j]
            # Two levels of nesting.
            for b in range(3):
                for b_inner in range(j):
                    y[j] += x[j]
            # Three levels of nesting.
            for c in range(3):
                for c_mid in range(2):
                    for c_inner in range(j):
                        y[j] += x[j]

    x.fill(1.0)
    for idx in range(N):
        y.grad[idx] = 1.0
    compute_y()
    compute_y.grad()
    for idx in range(N):
        expected = idx * 10.0
        assert y[idx] == expected
        assert x.grad[idx] == expected


@test_utils.test(require=ti.extension.adstack)
def test_multiple_ib_inner_mixed():
    """Sibling loop nests where a deeper loop mixes a statement and a loop.

    Every outer iteration contributes
    3 + 3 * (2 + 2 * (1 + 2)) + 3 * 2 * 2 = 39, so both y and x.grad are
    expected to be 2 * 39 = 78.
    """
    x = ti.field(float, (), needs_grad=True)
    y = ti.field(float, (), needs_grad=True)

    @ti.kernel
    def compute_y():
        for outer in range(2):
            for a in range(3):
                y[None] += x[None]
            for b in range(3):
                for b_first in range(2):
                    y[None] += x[None]
                for b_second in range(2):
                    # An accumulation followed by a nested loop in the same
                    # body.
                    y[None] += x[None]
                    for b_deep in range(2):
                        y[None] += x[None]
            for c in range(3):
                for c_mid in range(2):
                    for c_inner in range(2):
                        y[None] += x[None]

    expected = 78.0
    x[None] = 1.0
    with ti.Tape(y):
        compute_y()

    assert y[None] == expected
    assert x.grad[None] == expected

0 comments on commit be2357e

Please sign in to comment.