Skip to content

Commit

Permalink
[autodiff] Extract shared components for reverse and forward mode (#5088
Browse files Browse the repository at this point in the history
)

extract shared components for reverse and forward mode
  • Loading branch information
erizmr authored Jun 2, 2022
1 parent 8dc598d commit 52a7cd8
Showing 1 changed file with 107 additions and 80 deletions.
187 changes: 107 additions & 80 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,9 @@ class ReverseOuterLoops : public BasicStmtVisitor {
}
};

// Generate the adjoint version of an independent block

class MakeAdjoint : public IRVisitor {
private:
// Base class for both reverse (make adjoint) and forward (make dual) mode
class ADTransform : public IRVisitor {
protected:
Stmt *constant(float32 x) {
return insert<ConstStmt>(TypedConstant(x));
}
Expand Down Expand Up @@ -556,6 +555,107 @@ class MakeAdjoint : public IRVisitor {
}

public:
virtual Stmt *insert_grad_stmt(std::unique_ptr<Stmt> &&stmt) = 0;

template <typename T, typename... Args>
Stmt *insert(Args &&...args) {
return insert_grad_stmt(Stmt::make<T>(args...));
}

void visit(AllocaStmt *alloca) override {
// do nothing.
}

void visit(AdStackAllocaStmt *alloca) override {
// do nothing.
}

void visit(ArgLoadStmt *stmt) override {
// do nothing.
}

void visit(LoopIndexStmt *stmt) override {
// do nothing.
}

void visit(PrintStmt *print_stmt) override {
// do nothing
}

void visit(ConstStmt *const_stmt) override {
// do nothing
}

void visit(WhileControlStmt *stmt) override {
TI_NOT_IMPLEMENTED
}

void visit(ContinueStmt *stmt) override {
TI_NOT_IMPLEMENTED;
}

void visit(WhileStmt *stmt) override {
TI_NOT_IMPLEMENTED
}

void visit(GlobalPtrStmt *stmt) override {
// do nothing
}

Stmt *load(Stmt *alloc) {
TI_ASSERT(alloc != nullptr);
if (alloc->is<AllocaStmt>()) {
return insert<LocalLoadStmt>(LocalAddress(alloc, 0));
} else {
// non alloca
return alloc;
}
}

bool gradients_stopped(GlobalLoadStmt *stmt, SNode *snode) {
for (auto block = stmt->parent; block; block = block->parent_block()) {
for (auto s : block->stop_gradients) {
if (s == snode) {
return true;
}
}
}
return false;
}

void visit(ElementShuffleStmt *stmt) override {
TI_NOT_IMPLEMENTED
}

void visit(AssertStmt *stmt) override {
// do nothing
}

void visit(RangeAssumptionStmt *stmt) override {
// do nothing
}

void visit(LinearizeStmt *stmt) override {
// do nothing
}

void visit(BitExtractStmt *stmt) override {
// do nothing
}

void visit(IntegerOffsetStmt *stmt) override {
// do nothing
}

void visit(RandStmt *stmt) override {
TI_ERROR("RandStmt not supported in AutoDiff for now.");
}
};

// Generate the adjoint version of an independent block
class MakeAdjoint : public ADTransform {
public:
using ADTransform::visit;
Block *current_block;
Block *alloca_block;
// Backup the forward pass (the forward pass might be modified during the
Expand Down Expand Up @@ -593,17 +693,12 @@ class MakeAdjoint : public IRVisitor {
}
}

Stmt *insert_back(std::unique_ptr<Stmt> &&stmt) {
Stmt *insert_grad_stmt(std::unique_ptr<Stmt> &&stmt) override {
auto ptr = stmt.get();
current_block->insert(std::move(stmt), -1);
return ptr;
}

template <typename T, typename... Args>
Stmt *insert(Args &&...args) {
return insert_back(Stmt::make<T>(args...));
}

// Accumulate [value] to the adjoint of [primal]
void accumulate(Stmt *primal, Stmt *value) {
auto alloca_ = adjoint(primal);
Expand Down Expand Up @@ -675,22 +770,6 @@ class MakeAdjoint : public IRVisitor {
return adjoint_stmt[stmt];
}

void visit(AllocaStmt *alloca) override {
// do nothing.
}

void visit(AdStackAllocaStmt *alloca) override {
// do nothing.
}

void visit(ArgLoadStmt *stmt) override {
// do nothing.
}

void visit(LoopIndexStmt *stmt) override {
// do nothing.
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->op_type == UnaryOpType::floor ||
stmt->op_type == UnaryOpType::ceil) {
Expand Down Expand Up @@ -827,34 +906,14 @@ class MakeAdjoint : public IRVisitor {
}
current_block = old_current_block;
}
insert_back(std::move(new_if));
}

void visit(PrintStmt *print_stmt) override {
// do nothing
}

void visit(ConstStmt *const_stmt) override {
// do nothing
}

void visit(WhileControlStmt *stmt) override {
TI_NOT_IMPLEMENTED
}

void visit(ContinueStmt *stmt) override {
TI_NOT_IMPLEMENTED;
}

void visit(WhileStmt *stmt) override {
TI_NOT_IMPLEMENTED
insert_grad_stmt(std::move(new_if));
}

void visit(RangeForStmt *for_stmt) override {
auto new_for = for_stmt->clone();
auto new_for_ptr = new_for->as<RangeForStmt>();
new_for_ptr->reversed = !new_for_ptr->reversed;
insert_back(std::move(new_for));
insert_grad_stmt(std::move(new_for));
const int len = new_for_ptr->body->size();

for (int i = 0; i < len; i++) {
Expand Down Expand Up @@ -889,10 +948,6 @@ class MakeAdjoint : public IRVisitor {
for_stmt->body->accept(this);
}

void visit(GlobalPtrStmt *stmt) override {
// do nothing
}

// Equivalent to AdStackLoadTopStmt when no stack is needed
void visit(LocalLoadStmt *stmt) override {
// TI_ASSERT(!needs_grad(stmt->ret_type));
Expand Down Expand Up @@ -999,34 +1054,6 @@ class MakeAdjoint : public IRVisitor {
}
stmt->parent->erase(stmt);
}

void visit(ElementShuffleStmt *stmt) override {
TI_NOT_IMPLEMENTED
}

void visit(AssertStmt *stmt) override {
// do nothing
}

void visit(RangeAssumptionStmt *stmt) override {
// do nothing
}

void visit(LinearizeStmt *stmt) override {
// do nothing
}

void visit(BitExtractStmt *stmt) override {
// do nothing
}

void visit(IntegerOffsetStmt *stmt) override {
// do nothing
}

void visit(RandStmt *stmt) override {
TI_ERROR("RandStmt not supported in AutoDiff for now.");
}
};

class BackupSSA : public BasicStmtVisitor {
Expand Down

0 comments on commit 52a7cd8

Please sign in to comment.