Support backward final hook (#44686)
JiabinYang authored Jul 29, 2022
1 parent b7496bc commit 8c43c0f
Showing 17 changed files with 259 additions and 130 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/reducer.cc
@@ -321,7 +321,7 @@ EagerReducer::EagerReducer(
const auto &accumulation_grad_node =
std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
accumulation_grad_node->RegisterReduceHook(
-            std::make_shared<egr::CppTensorVoidHook>(reduce_hook));
+            std::make_shared<egr::CppVoidHook>(reduce_hook));

gradnode_index_map_[grad_node.get()] = global_var_index;
}
2 changes: 1 addition & 1 deletion paddle/fluid/eager/accumulation/accumulation_node.cc
@@ -127,7 +127,7 @@ GradNodeAccumulation::operator()(
}

void GradNodeAccumulation::RegisterReduceHook(
-    std::shared_ptr<TensorVoidHook>&& hook) {
+    std::shared_ptr<VoidHook>&& hook) {
reduce_hooks_.emplace_back(std::move(hook));
}

4 changes: 2 additions & 2 deletions paddle/fluid/eager/accumulation/accumulation_node.h
@@ -51,7 +51,7 @@ class GradNodeAccumulation : public GradNodeBase {
/**
* Register ReduceHook
* **/
-  void RegisterReduceHook(std::shared_ptr<TensorVoidHook>&& hook);
+  void RegisterReduceHook(std::shared_ptr<VoidHook>&& hook);

/**
* Apply ReduceHook here
@@ -70,7 +70,7 @@ class GradNodeAccumulation : public GradNodeBase {
// TODO(Jiabin): remove this when we make our clear gradient really cleared;
bool is_fake_empty_ = {false};
std::weak_ptr<paddle::experimental::Tensor> weak_grad_;
-  std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
+  std::vector<std::shared_ptr<VoidHook>> reduce_hooks_;
std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>
retain_grad_hook_;
19 changes: 18 additions & 1 deletion paddle/fluid/eager/api/utils/global_utils.h
@@ -18,11 +18,11 @@
#include <atomic>
#include <memory>

#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/type_defs.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/utils/small_vector.h"

namespace egr {
class UniqueNameGenerator {
public:
@@ -85,6 +85,22 @@ class Controller {
GetCustomEdgesSlotMap() {
return custom_edges_slot_map_;
}
+  // For Cpp Hook
+  void RegisterBackwardFinalHook(const std::function<void()>& call_back) {
+    VLOG(6) << "RegisterBackwardFinalHook";
+    final_backward_hooks_.emplace_back(
+        std::make_shared<CppVoidHook>(std::move(call_back)));
+    VLOG(6) << "Size: " << final_backward_hooks_.size();
+  }
+  // For Python hook
+  void RegisterBackwardFinalHook(const std::shared_ptr<VoidHook>& call_back) {
+    final_backward_hooks_.emplace_back(call_back);
+  }
+  const std::vector<std::shared_ptr<VoidHook>>& FinalBackwardHooks() const {
+    return final_backward_hooks_;
+  }
+
+  void ClearFinalBackwardHooks() { final_backward_hooks_.clear(); }

private:
Controller() = default;
@@ -98,6 +114,7 @@ class Controller {
std::unordered_map<std::string,
std::vector<std::vector<std::unordered_map<int, int>>>>
custom_edges_slot_map_;
+  std::vector<std::shared_ptr<VoidHook>> final_backward_hooks_;
DISABLE_COPY_AND_ASSIGN(Controller);
};

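Taken together, the Controller additions above form the registry behind this feature: hooks accumulate in final_backward_hooks_ and are drained after a backward pass. A minimal usage sketch from C++, assuming a Paddle build where these eager headers are available; the lambda bodies are placeholders, not part of this commit:

#include <memory>

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/hooks.h"

void RegisterFinalHooksExample() {
  // C++ overload: the Controller wraps the callable in a CppVoidHook.
  egr::Controller::Instance().RegisterBackwardFinalHook(
      []() { VLOG(6) << "backward pass finished"; });

  // Overload used by the Python bindings: pass a pre-built VoidHook.
  auto hook = std::make_shared<egr::CppVoidHook>([]() { /* side effect */ });
  egr::Controller::Instance().RegisterBackwardFinalHook(hook);
}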
18 changes: 12 additions & 6 deletions paddle/fluid/eager/api/utils/hook_utils.cc
@@ -25,17 +25,20 @@ namespace egr_utils_api {

int64_t RegisterGradientHookForTensor(
const paddle::experimental::Tensor& tensor,
-    std::shared_ptr<egr::TensorHook>&& hook) {
+    const std::function<paddle::experimental::Tensor(
+        const paddle::experimental::Tensor&)>& hook) {
// Find grad_node and out_rank from AutogradMeta
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
auto rank_info = EagerUtils::unsafe_autograd_meta(tensor)->OutRankInfo();

return grad_node->RegisterGradientHook(
-      rank_info.first, rank_info.second, std::move(hook));
+      rank_info.first,
+      rank_info.second,
+      std::move(std::make_shared<CppTensorHook>(hook)));
}

void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
-                                 std::shared_ptr<egr::TensorVoidHook>&& hook) {
+                                 const std::function<void()>& hook) {
if (IsLeafTensor(tensor)) {
VLOG(6) << "Register ReduceHook for leaf tensor";
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
@@ -46,7 +49,8 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
"with type: GradNodeAccumulation"));
auto accumulation_grad_node =
std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node);
-    accumulation_grad_node->RegisterReduceHook(std::move(hook));
+    accumulation_grad_node->RegisterReduceHook(
+        std::move(std::make_shared<CppVoidHook>(hook)));
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Only can register reduce hook for leaf Tensor."));
@@ -90,10 +94,12 @@ void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
};

// Append to GradientHooks
-    RegisterGradientHookForTensor(tensor,
-                                  std::make_shared<egr::CppTensorHook>(hook));
+    RegisterGradientHookForTensor(tensor, hook);
}
}

+void RegisterBackwardFinalHook(const std::function<void()>& hook) {
+  Controller::Instance().RegisterBackwardFinalHook(hook);
+}
} // namespace egr_utils_api
} // namespace egr
7 changes: 5 additions & 2 deletions paddle/fluid/eager/api/utils/hook_utils.h
@@ -23,11 +23,14 @@ namespace egr_utils_api {

int64_t RegisterGradientHookForTensor(
const paddle::experimental::Tensor& tensor,
-    std::shared_ptr<egr::TensorHook>&& hook);
+    const std::function<paddle::experimental::Tensor(
+        const paddle::experimental::Tensor&)>& hook);

void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
-                                 std::shared_ptr<egr::TensorVoidHook>&& hook);
+                                 const std::function<void()>& hook);
void RetainGradForTensor(const paddle::experimental::Tensor& tensor);

+void RegisterBackwardFinalHook(const std::function<void()>& hook);

} // namespace egr_utils_api
} // namespace egr
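With these signatures, callers now hand in plain std::function objects and the wrapping into CppTensorHook or CppVoidHook happens inside the utility. A hedged sketch of all three entry points; tensor construction is omitted and the lambdas are illustrative only:

#include "paddle/fluid/eager/api/utils/hook_utils.h"

void HookDemo(const paddle::experimental::Tensor& tensor) {
  // Gradient hook: receives the incoming grad and returns the grad to use.
  egr::egr_utils_api::RegisterGradientHookForTensor(
      tensor, [](const paddle::experimental::Tensor& grad) { return grad; });

  // Reduce hook: leaf tensors only; non-leaf tensors raise a fatal error.
  egr::egr_utils_api::RegisterReduceHookForTensor(tensor, []() {});

  // Final hook: runs once after the whole backward pass completes.
  egr::egr_utils_api::RegisterBackwardFinalHook([]() {});
}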
6 changes: 6 additions & 0 deletions paddle/fluid/eager/backward.cc
@@ -371,6 +371,12 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
}

VLOG(6) << "Run Backward Final hook size: "
<< egr::Controller::Instance().FinalBackwardHooks().size();
for (auto& hook : egr::Controller::Instance().FinalBackwardHooks()) {
(*hook)();
}
egr::Controller::Instance().ClearFinalBackwardHooks();
if (!is_general_grad) return {};
return GeneralGrad::Instance().GetResults(inputs, allow_unused, create_graph);
}
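Note the order above: the hooks run after gradient computation, and the registry is cleared immediately afterwards, so a final hook fires at most once and must be re-registered for each backward pass. A hedged illustration, assuming targets already carries a grad graph and that egr::Backward is callable from this context:

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/backward.h"

void FinalHookFiresOnce(
    const std::vector<paddle::experimental::Tensor>& targets) {
  egr::Controller::Instance().RegisterBackwardFinalHook(
      []() { VLOG(6) << "runs after this backward pass"; });
  egr::Backward(targets, {});  // hook fires here, then the list is cleared
  // A later egr::Backward() call runs no final hooks unless re-registered.
}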
13 changes: 7 additions & 6 deletions paddle/fluid/eager/hooks.h
@@ -29,16 +29,16 @@ class TensorHook {
const paddle::experimental::Tensor& var) = 0;
};

-class TensorVoidHook {
+class VoidHook {
public:
-  virtual ~TensorVoidHook() = default;
+  virtual ~VoidHook() = default;
virtual void operator()() = 0;
};

class CppTensorHook : public TensorHook {
public:
-  explicit CppTensorHook(std::function<paddle::experimental::Tensor(
-      const paddle::experimental::Tensor&)>&& fn)
+  explicit CppTensorHook(const std::function<paddle::experimental::Tensor(
+      const paddle::experimental::Tensor&)>& fn)
: fn_(std::move(fn)) {}

paddle::experimental::Tensor operator()(
@@ -52,13 +52,14 @@ class CppTensorHook : public TensorHook {
fn_;
};

-class CppTensorVoidHook : public TensorVoidHook {
+class CppVoidHook : public VoidHook {
public:
-  explicit CppTensorVoidHook(std::function<void()>&& fn) : fn_(std::move(fn)) {}
+  explicit CppVoidHook(const std::function<void()>& fn) : fn_(std::move(fn)) {}

void operator()() override { return fn_(); }

private:
std::function<void()> fn_;
};

} // namespace egr
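The rename drops "Tensor" from the void-hook types, since they never see a tensor: TensorHook transforms an incoming gradient, while VoidHook is a plain callback used for reduce hooks and the new backward final hooks. A small construction sketch, assuming hooks.h brings the Tensor type into scope:

#include <memory>

#include "paddle/fluid/eager/hooks.h"

// Plain callback, no tensor involved.
std::shared_ptr<egr::VoidHook> void_hook =
    std::make_shared<egr::CppVoidHook>([]() { /* side effect only */ });

// Gradient-transforming hook: maps the incoming grad to a new tensor.
std::shared_ptr<egr::TensorHook> tensor_hook =
    std::make_shared<egr::CppTensorHook>(
        [](const paddle::experimental::Tensor& grad) { return grad; });

Because both Cpp* constructors now take const references, the same callable can be registered in several places without an explicit copy or move at each call site.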
@@ -328,8 +328,7 @@ TEST(AccumulationNode, Tensor) {
VLOG(6) << "Running Reduce Hook";
};

-  node->RegisterReduceHook(
-      std::make_shared<egr::CppTensorVoidHook>(reduce_hook_1));
+  node->RegisterReduceHook(std::make_shared<egr::CppVoidHook>(reduce_hook_1));

// operator()
paddle::experimental::Tensor _ret = node->operator()(et0_vec)[0][0];
@@ -354,8 +353,7 @@ TEST(AccumulationNode, Tensor) {
ret_et0_ptr[0] = 100.0; // set to 100.0
VLOG(6) << "Running Reduce Hook";
};
-  node->RegisterReduceHook(
-      std::make_shared<egr::CppTensorVoidHook>(reduce_hook_2));
+  node->RegisterReduceHook(std::make_shared<egr::CppVoidHook>(reduce_hook_2));
node->ApplyReduceHooks();

// Check ApplyReduceHooks result
12 changes: 6 additions & 6 deletions paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc
@@ -256,26 +256,26 @@ TEST(FwdBwdJoint, GradientHook) {
true /*bias_after_scale*/,
true /*trace_backward*/);
egr_utils_api::RetainGradForTensor(out0); // hook: +5
-  egr_utils_api::RegisterGradientHookForTensor(
-      out0, std::make_shared<egr::CppTensorHook>(hook_function));  // hook: +5
+  egr_utils_api::RegisterGradientHookForTensor(out0,
+                                               hook_function);  // hook: +5

// Run Forward Node 1
float scale1 = 5.0;
float bias1 = 10.0;
paddle::experimental::Tensor out1 = egr::scale(
out0, scale1, bias1, true /*bias_after_scale*/, true /*trace_backward*/);
egr_utils_api::RetainGradForTensor(out1); // hook: +5
-  egr_utils_api::RegisterGradientHookForTensor(
-      out1, std::make_shared<egr::CppTensorHook>(hook_function));  // hook: +5
+  egr_utils_api::RegisterGradientHookForTensor(out1,
+                                               hook_function);  // hook: +5

// Run Forward Node 2
float scale2 = 10.0;
float bias2 = 20.0;
paddle::experimental::Tensor out2 = egr::scale(
out0, scale2, bias2, true /*bias_after_scale*/, true /*trace_backward*/);
egr_utils_api::RetainGradForTensor(out2); // hook: +5
-  egr_utils_api::RegisterGradientHookForTensor(
-      out2, std::make_shared<egr::CppTensorHook>(hook_function));  // hook: +5
+  egr_utils_api::RegisterGradientHookForTensor(out2,
+                                               hook_function);  // hook: +5

// 4. Run Backward
std::vector<paddle::experimental::Tensor> outs = {out1, out2};
12 changes: 4 additions & 8 deletions paddle/fluid/eager/tests/task_tests/hook_test.cc
@@ -95,8 +95,7 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
auto_grad_meta));

-  egr_utils_api::RegisterGradientHookForTensor(
-      target_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
+  egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook_function);
egr_utils_api::RetainGradForTensor(
target_tensor); // result: 1.0 + 3.0 = 4.0
egr_utils_api::RetainGradForTensor(
@@ -122,8 +121,7 @@ TEST(RetainGrad, HookBeforeRetainGrad) {
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
tmp_tensor0.mutable_autograd_meta()));

-  egr_utils_api::RegisterGradientHookForTensor(
-      leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
+  egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook_function);
egr_utils_api::RetainGradForTensor(
leaf_tensor); // result: 4.0*5.0 + 3.0 = 23.0
}
@@ -173,8 +171,7 @@ TEST(RetainGrad, HookAfterRetainGrad) {
auto_grad_meta));

egr_utils_api::RetainGradForTensor(target_tensor); // result: 1.0
-  egr_utils_api::RegisterGradientHookForTensor(
-      target_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
+  egr_utils_api::RegisterGradientHookForTensor(target_tensor, hook_function);
}

// Retain Grad for leaf tensor1
@@ -193,8 +190,7 @@ TEST(RetainGrad, HookAfterRetainGrad) {
std::dynamic_pointer_cast<paddle::experimental::AbstractAutogradMeta>(
tmp_tensor0.mutable_autograd_meta()));

-  egr_utils_api::RegisterGradientHookForTensor(
-      leaf_tensor, std::make_shared<egr::CppTensorHook>(hook_function));
+  egr_utils_api::RegisterGradientHookForTensor(leaf_tensor, hook_function);
}

Backward(target_tensors, {});